Mixed Precision Training of Neural ODEs
Pith reviewed 2026-05-18 03:45 UTC · model grok-4.3
The pith
Neural ODEs can train with low-precision velocity evaluations and state storage while using dynamic adjoint scaling and high-precision accumulation to keep accuracy intact.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
Explicit ODE solvers for Neural ODEs can be made numerically reliable when the neural-network velocity field is evaluated, and intermediate states are stored, in low precision, provided that custom dynamic adjoint scaling is applied and both the solution trajectory and its gradients are accumulated in higher precision. The reported outcome is approximately 50 percent memory reduction and up to 2x speedup, with accuracy comparable to single-precision training on image-classification and generative-model tasks.
What carries the argument
Custom dynamic adjoint scaling together with high-precision accumulation of solutions and gradients, which protects against roundoff while low-precision arithmetic is used for the velocity network and state storage.
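The role of high-precision accumulation can be illustrated with a toy forward Euler loop. This is a minimal sketch, not the paper's implementation and not the rampde API: the velocity dy/dt = -y stands in for the neural network, half precision is emulated with the standard-library struct format "e", and the names to_fp16, euler_all_low and euler_mixed are hypothetical.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def velocity(y: float) -> float:
    """Toy stand-in for the neural-network velocity field: dy/dt = -y."""
    return -y


def euler_all_low(y0: float, h: float, steps: int) -> float:
    """Forward Euler with every quantity, including the running state, in fp16."""
    y = to_fp16(y0)
    for _ in range(steps):
        v = to_fp16(velocity(y))
        y = to_fp16(y + to_fp16(h * v))  # the update itself is rounded to fp16
    return y


def euler_mixed(y0: float, h: float, steps: int) -> float:
    """Forward Euler with fp16 state copy and velocity, float64 accumulator."""
    y = y0  # high-precision running solution
    for _ in range(steps):
        y_stored = to_fp16(y)            # low-precision stored state
        v = to_fp16(velocity(y_stored))  # low-precision velocity evaluation
        y = y + h * v                    # accumulation stays in float64
    return y


h, steps = 1e-4, 10_000  # integrate from t = 0 to T = 1
exact = math.exp(-1.0)
low = euler_all_low(1.0, h, steps)
mixed = euler_mixed(1.0, h, steps)
print(f"exact={exact:.5f}  all-fp16={low:.5f}  mixed={mixed:.5f}")
```

In this toy setting the all-fp16 run never leaves y = 1.0, because the increment of about 1e-4 is smaller than half the fp16 spacing just below 1 (about 2.4e-4) and is rounded away at every step. The mixed run keeps the small increments and stays close to exp(-1).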
If this is right
- Memory footprint for Neural ODE training drops by about 50 percent.
- Training runs up to twice as fast while accuracy stays comparable to single precision.
- The scheme works for explicit ODE solvers across image classification and generative modeling tasks.
- An extendable PyTorch package supplies the implementation as a drop-in replacement for existing Neural ODE code.
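The memory figure can be made concrete with back-of-the-envelope arithmetic on the stored states alone. This is a sketch under assumptions: the problem size below is hypothetical, stored_state_bytes is an illustrative helper, and the count ignores weights, optimizer state and activations inside the velocity network, so it is not a reproduction of the paper's measurement.

```python
import struct

BYTES_FP32 = struct.calcsize("f")  # 4 bytes per single-precision value
BYTES_FP16 = struct.calcsize("e")  # 2 bytes per half-precision value


def stored_state_bytes(steps: int, batch: int, dim: int, bytes_per_value: int) -> int:
    """Bytes needed to keep one state tensor per time step for the backward pass."""
    return steps * batch * dim * bytes_per_value


# Hypothetical problem size, chosen only for illustration.
steps, batch, dim = 64, 128, 4096
fp32 = stored_state_bytes(steps, batch, dim, BYTES_FP32)
fp16 = stored_state_bytes(steps, batch, dim, BYTES_FP16)
saving = 1.0 - fp16 / fp32
print(f"fp32: {fp32 / 2**20:.0f} MiB, fp16: {fp16 / 2**20:.0f} MiB, saving: {saving:.0%}")
# prints "fp32: 128 MiB, fp16: 64 MiB, saving: 50%"
```

The stored trajectory grows linearly with the number of time steps, so halving the bytes per value halves that part of the footprint regardless of the step count.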
Where Pith is reading between the lines
- The same scaling-plus-accumulation idea could be tested on other continuous-depth architectures that rely on adjoint methods.
- Further reduction in precision might become feasible if the dynamic scaling rule is tuned per layer or per time interval.
- Hardware with native low-precision tensor cores could see even larger speed gains once the adjoint scaling is implemented directly in those formats.
Load-bearing premise
The combination of dynamic adjoint scaling and high-precision accumulation is enough to stop roundoff errors and instabilities from appearing when velocity evaluations and intermediate states are computed and stored in low precision.
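Why scaling matters for the adjoint can be seen from how half precision handles very small values. The sketch below is a generic power-of-two scaling rule in the spirit of loss scaling. The paper's actual scaling rule is not described in this summary, so pick_scale, store_scaled, load_unscaled and the target exponent of 10 are all assumptions made for illustration.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE half-precision value; overflow becomes +/-inf."""
    try:
        return struct.unpack("e", struct.pack("e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def pick_scale(values, target_exponent: int = 10) -> float:
    """Power-of-two scale that moves max|value| into [2**(t-1), 2**t)."""
    largest = max(abs(v) for v in values)
    if largest == 0.0 or not math.isfinite(largest):
        return 1.0
    _, exponent = math.frexp(largest)  # largest = m * 2**exponent, 0.5 <= m < 1
    return math.ldexp(1.0, target_exponent - exponent)


def store_scaled(values):
    """Scale, then round to fp16; the scale is recomputed from the current values."""
    scale = pick_scale(values)
    return [to_fp16(v * scale) for v in values], scale


def load_unscaled(stored, scale):
    """Undo the scaling in high precision (exact for a power-of-two scale)."""
    return [s / scale for s in stored]


adjoint = [3.0e-9, -1.0e-8, 2.5e-9]    # all below the smallest fp16 subnormal
naive = [to_fp16(a) for a in adjoint]  # a direct cast flushes every entry to zero
stored, scale = store_scaled(adjoint)
recovered = load_unscaled(stored, scale)
print("naive:", naive)
print("recovered:", recovered)
```

Because the scale is a power of two, multiplying and dividing by it introduces no rounding of its own. The only error left is the fp16 rounding of the scaled values, which is relative rather than a flush to zero.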
What would settle it
The claim would be undermined if training a Neural ODE on one of the paper's challenging test cases with the mixed-precision scheme produced noticeably lower accuracy, or diverged, compared with the identical model trained in single precision.
Original abstract
Exploiting low-precision computations has become a standard strategy in deep learning to address the growing computational costs imposed by ever larger models and datasets. However, naively performing all computations in low precision can lead to roundoff errors and instabilities. Therefore, mixed precision training schemes usually store the weights in high precision and use low-precision computations only for whitelisted operations. Despite their success, these principles are currently not reliable for training continuous-time architectures such as neural ordinary differential equations (Neural ODEs). This paper presents a mixed precision training framework for neural ODEs consisting of explicit ODE solvers and a custom backpropagation scheme and shows their effectiveness in a range of learning tasks. Our scheme uses low-precision computations for evaluating the velocity, parameterized by the neural network, and for storing intermediate states, while numerical reliability is provided by custom dynamic adjoint scaling and by accumulating the solution and gradients in higher precision. These contributions address two key challenges in training neural ODEs: the computational cost of repeated network evaluations and the growth of memory requirements with the number of time steps or layers. Along with the paper we publish our extendable, open-source PyTorch package rampde, whose syntax resembles that of leading packages to provide a drop-in replacement in existing codes. We demonstrate the reliability and effectiveness of our scheme using challenging test cases and on neural ODE applications in image classification and generative models, achieving approximately 50% memory reduction and up to 2x speedup while maintaining accuracy comparable to single-precision training.
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The manuscript proposes a mixed-precision training framework for Neural ODEs with explicit solvers. Low-precision arithmetic is used for neural-network velocity evaluations and intermediate state storage, while a custom dynamic adjoint scaling is applied in the backward pass and solutions/gradients are accumulated in higher precision to preserve numerical reliability. The authors report approximately 50% memory reduction and up to 2x speedup on image-classification and generative-model tasks while maintaining accuracy comparable to single-precision baselines, and release an open-source PyTorch package (rampde) as a drop-in replacement.
Significance. If the numerical-stability claims are substantiated, the work would meaningfully advance efficient training of continuous-depth models by directly addressing the repeated network evaluations and memory growth that limit Neural ODE scalability. The open-source package is a concrete strength that could accelerate adoption. The empirical results on non-trivial applications lend practical relevance, but the absence of forward-pass error analysis makes the reliability guarantee conditional rather than general.
major comments (2)
- [Method (forward-pass description)] Method section describing the forward integration and dynamic adjoint scaling: no error bound, stability analysis, or propagation estimate is supplied for round-off accumulation when low-precision velocities and states are used inside explicit solvers; the scaling is presented only as a backward-pass device, leaving open whether forward perturbations degrade gradients on non-linear or long-horizon dynamics even with high-precision final accumulation.
- [Experiments] Experiments section (results on test cases and applications): the reported memory and speedup figures are not accompanied by ablations on precision choices or scaling hyperparameters, nor by verification that post-hoc adjustments preserved the central accuracy claim; this weakens the cross-task generality of the 50% memory / 2x speedup statement.
minor comments (2)
- [Method] Notation for the dynamic scaling factor could be introduced earlier and used consistently when contrasting forward and adjoint passes.
- [Discussion] A short discussion of how the scheme interacts with common explicit solvers (e.g., RK4 versus Euler) would clarify the scope of the reported speed-ups.
Simulated Author's Rebuttal
We thank the referee for their constructive comments on our manuscript. We address the major comments point by point below, indicating where revisions will be made to improve the paper.
- Referee: [Method (forward-pass description)] Method section describing the forward integration and dynamic adjoint scaling: no error bound, stability analysis, or propagation estimate is supplied for round-off accumulation when low-precision velocities and states are used inside explicit solvers; the scaling is presented only as a backward-pass device, leaving open whether forward perturbations degrade gradients on non-linear or long-horizon dynamics even with high-precision final accumulation.
Authors: We agree that a formal error analysis for the forward pass would enhance the theoretical grounding of our mixed-precision scheme. The manuscript emphasizes empirical evidence from challenging test cases and applications, where the high-precision accumulation of solutions and gradients helps mitigate round-off errors. The dynamic adjoint scaling is indeed focused on stabilizing the backward pass. To address this, we will include a new subsection discussing error propagation estimates based on standard numerical ODE theory adapted to our low-precision setting, along with additional forward-pass accuracy metrics in the revised manuscript.
Revision: yes
- Referee: [Experiments] Experiments section (results on test cases and applications): the reported memory and speedup figures are not accompanied by ablations on precision choices or scaling hyperparameters, nor by verification that post-hoc adjustments preserved the central accuracy claim; this weakens the cross-task generality of the 50% memory / 2x speedup statement.
Authors: We acknowledge the value of ablations to support the generality of our results. The original experiments demonstrate performance on image classification and generative models with the proposed mixed-precision approach. We will add ablations varying the precision levels and scaling factors, and include checks confirming that accuracy remains comparable to single-precision baselines across these variations in the revised version.
Revision: yes
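The kind of forward-pass estimate discussed in the first exchange can be illustrated with the textbook perturbed-Euler argument. This is a sketch under assumptions that are not taken from the paper: forward Euler with step $h$ on $[0,T]$ with $T = Nh$; a velocity $f$ that is Lipschitz with constant $L$; a low-precision velocity error $\delta_n$ bounded by $\delta$; and, in the second case, a state update rounded with unit roundoff $u$, where $Y$ bounds the magnitude of the unrounded update.

```latex
\begin{align*}
% Case 1: low-precision velocity, accumulation treated as exact
\tilde y_{n+1} &= \tilde y_n + h\bigl(f(\tilde y_n) + \delta_n\bigr),
  \qquad |\delta_n| \le \delta, \qquad e_n = \tilde y_n - y_n, \quad e_0 = 0, \\
|e_{n+1}| &\le (1 + hL)\,|e_n| + h\,\delta
  \;\Longrightarrow\;
  |e_N| \le \delta\,\frac{(1+hL)^N - 1}{L} \le \delta\,\frac{e^{LT} - 1}{L}. \\[1ex]
% Case 2: the update itself is also rounded, relative error |\varepsilon_n| <= u
\tilde y_{n+1} &= \Bigl(\tilde y_n + h\bigl(f(\tilde y_n) + \delta_n\bigr)\Bigr)(1 + \varepsilon_n),
  \qquad |\varepsilon_n| \le u, \\
|e_{n+1}| &\le (1 + hL)\,|e_n| + h\,\delta + u\,Y
  \;\Longrightarrow\;
  |e_N| \le \Bigl(\delta + \frac{u\,Y}{h}\Bigr)\frac{e^{LT} - 1}{L}.
\end{align*}
```

In the first case the bound is independent of the number of steps for fixed $T$. In the second, the $uY/h$ term grows as the step size shrinks, which is consistent with the role the summary assigns to high-precision accumulation. Neither bound addresses the adjoint pass.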
Circularity Check
No significant circularity; method components are independent engineering choices validated empirically
The paper proposes a mixed-precision framework for Neural ODEs using low-precision velocity evaluations and state storage, with custom dynamic adjoint scaling and high-precision accumulation for stability. These elements are introduced as novel contributions rather than derived from or equivalent to the performance metrics they enable. No self-definitional loops, fitted parameters renamed as predictions, or load-bearing self-citations appear in the described scheme. The effectiveness is shown through experiments on classification and generative tasks, keeping the central claims self-contained against external benchmarks rather than reducing to the paper's own inputs by construction.
Axiom & Free-Parameter Ledger
axioms (1)
- domain assumption: Low-precision computations for velocity evaluation and intermediate states remain numerically stable when paired with dynamic adjoint scaling and high-precision accumulation of solutions and gradients.
Lean theorems connected to this paper
- IndisputableMonolith/Cost/FunctionalEquation.lean: washburn_uniqueness_aczel
Relation between the paper passage and the cited Recognition theorem: unclear
Paper passage: Our scheme uses low-precision computations for evaluating the velocity, parameterized by the neural network, and for storing intermediate states, while numerical reliability is provided by custom dynamic adjoint scaling and by accumulating the solution and gradients in higher precision
- IndisputableMonolith/Foundation/AlphaCoordinateFixation.lean: J_uniquely_calibrated_via_higher_derivative
Relation between the paper passage and the cited Recognition theorem: unclear
Paper passage: We show that roundoff errors remain in the order of the unit roundoff of the low precision and do not grow uncontrollably with the number of time steps
What do these tags mean?
- matches: The paper's claim is directly supported by a theorem in the formal canon.
- supports: The theorem supports part of the paper's argument, but the paper may add assumptions or extra steps.
- extends: The paper goes beyond the formal theorem; the theorem is a base layer rather than the whole result.
- uses: The paper appears to rely on the theorem as machinery.
- contradicts: The paper's claim conflicts with a theorem or certificate in the canon.
- unclear: Pith found a possible connection, but the passage is too broad, indirect, or ambiguous to say the theorem truly supports the claim.