
arXiv: 2510.23498 · v2 · submitted 2025-10-27 · 💻 cs.LG · cs.AI · cs.NA · math.NA

Mixed Precision Training of Neural ODEs

Pith reviewed 2026-05-18 03:45 UTC · model grok-4.3

classification: 💻 cs.LG · cs.AI · cs.NA · math.NA
keywords: mixed precision training · Neural ODEs · adjoint sensitivity · low-precision arithmetic · memory reduction · ODE solvers · dynamic scaling

The pith

Neural ODEs can train with low-precision velocity evaluations and state storage while using dynamic adjoint scaling and high-precision accumulation to keep accuracy intact.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper presents a mixed precision training framework for neural ODEs that performs velocity evaluations and stores intermediate states in low precision. Numerical stability comes from a custom dynamic adjoint scaling step together with high-precision accumulation of the solution and gradients. This targets the repeated network evaluations and growing memory cost that arise when integrating explicit ODE solvers over many time steps. A reader would care because the scheme delivers roughly 50 percent memory savings and up to twofold speedup on tasks such as image classification and generative modeling while matching single-precision accuracy. The framework is packaged as an open-source PyTorch library that can drop into existing Neural ODE code.

Core claim

By evaluating the neural-network velocity field and storing intermediate states in low precision while applying custom dynamic adjoint scaling and accumulating both the solution trajectory and its gradients in higher precision, explicit ODE solvers for Neural ODEs can be made numerically reliable, yielding approximately 50 percent memory reduction and up to 2x speedup without loss of training accuracy on image-classification and generative-model tasks.

What carries the argument

Custom dynamic adjoint scaling together with high-precision accumulation of solutions and gradients, which protects against roundoff while low-precision arithmetic is used for the velocity network and state storage.
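The effect of accumulating the state in higher precision can be seen in a toy example. The sketch below is an illustration under stated assumptions, not the rampde implementation: it integrates y' = -y with forward Euler, emulates IEEE half precision with Python's `struct` "e" format, uses a hypothetical `velocity` function as a stand-in for the network, and lets Python's double-precision float play the role of the "high precision" format.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a float to the nearest IEEE binary16 value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def velocity(y: float) -> float:
    """Hypothetical stand-in for a learned velocity field: f(y) = -y."""
    return -y


def euler(y0: float, t_end: float, num_steps: int, accumulate_low: bool) -> float:
    """Forward Euler with a low-precision velocity; the state update is rounded
    to fp16 if accumulate_low is True, otherwise kept in double precision."""
    h = t_end / num_steps
    y = y0
    for _ in range(num_steps):
        v = to_fp16(velocity(to_fp16(y)))  # velocity evaluated on an fp16 copy of the state
        if accumulate_low:
            y = to_fp16(y + to_fp16(h * v))  # state accumulated in fp16
        else:
            y = y + h * v  # state accumulated in higher precision
    return y


exact = math.exp(-1.0)
low = euler(1.0, 1.0, 4000, accumulate_low=True)
mixed = euler(1.0, 1.0, 4000, accumulate_low=False)
print(f"fp16 accumulation:           {low:.4f} (error {abs(low - exact):.2e})")
print(f"high-precision accumulation: {mixed:.4f} (error {abs(mixed - exact):.2e})")
```

With 4,000 steps the increment h·v drops below half an fp16 ulp of the state after a few dozen steps, so the fp16-accumulated run stalls near its starting value, while the double-accumulated run stays close to e^(-1) even though both use the same fp16 velocities.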

If this is right

  • Memory footprint for Neural ODE training drops by about 50 percent.
  • Training runs up to twice as fast while accuracy stays comparable to single precision.
  • The scheme works for explicit ODE solvers across image classification and generative modeling tasks.
  • An extendable PyTorch package supplies the implementation as a drop-in replacement for existing Neural ODE code.
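The roughly 50 percent figure is consistent with simple byte counting when the stored trajectory dominates memory. The sketch below rests on hypothetical assumptions: the forward pass keeps the initial state plus one state per time step for the backward pass, stored states shrink from 4-byte fp32 to 2-byte fp16 or bf16, weights stay in fp32, and the tensor sizes are made up for illustration.

```python
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "bf16": 2}


def training_memory_bytes(num_steps: int, state_numel: int, param_numel: int,
                          state_dtype: str) -> int:
    """Bytes for the stored trajectory (num_steps + 1 states) plus fp32 weights.

    Assumption: activations inside the velocity network, optimizer state,
    and workspace buffers are ignored.
    """
    trajectory = (num_steps + 1) * state_numel * BYTES_PER_ELEMENT[state_dtype]
    weights = param_numel * BYTES_PER_ELEMENT["fp32"]  # weights kept in fp32
    return trajectory + weights


# Hypothetical sizes: batch 128, 64 channels, 32x32 feature maps, 8 steps, 100k weights.
state_numel = 128 * 64 * 32 * 32
single = training_memory_bytes(8, state_numel, 100_000, "fp32")
mixed = training_memory_bytes(8, state_numel, 100_000, "fp16")
print(f"fp32 states: {single / 2**20:.1f} MiB, fp16 states: {mixed / 2**20:.1f} MiB, "
      f"ratio {mixed / single:.3f}")
```

The ratio lands slightly above one half because the fp32 weights do not shrink; the terms this count ignores would change the exact figure.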

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • The same scaling-plus-accumulation idea could be tested on other continuous-depth architectures that rely on adjoint methods.
  • Further reduction in precision might become feasible if the dynamic scaling rule is tuned per layer or per time interval.
  • Hardware with native low-precision tensor cores could see even larger speed gains once the adjoint scaling is implemented directly in those formats.

Load-bearing premise

The combination of dynamic adjoint scaling and high-precision accumulation is enough to stop roundoff errors and instabilities from appearing when velocity evaluations and intermediate states are computed and stored in low precision.
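One way to picture dynamic adjoint scaling is by analogy with dynamic loss scaling in standard mixed-precision training: multiply the adjoint by a scale factor before storing it in fp16 so small values do not flush to zero, back off when the scaled values overflow, and grow the factor again after a run of successful steps. The sketch below follows that generic rule; the class name, initial scale, and growth interval are hypothetical, the paper's actual scaling rule may differ, and fp16 is emulated with Python's `struct` module.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE binary16; values beyond the fp16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class DynamicAdjointScaler:
    """Hypothetical dynamic scale factor for storing adjoint values in fp16."""

    def __init__(self, init_scale: float = 2.0**16, growth_interval: int = 100):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def to_low(self, adjoint: list[float]) -> tuple[list[float], float]:
        """Return the scaled fp16 adjoint and the scale used to produce it."""
        if not all(math.isfinite(a) for a in adjoint):
            raise ValueError("adjoint must be finite before scaling")
        while True:
            s = self.scale
            scaled = [to_fp16(a * s) for a in adjoint]
            if all(math.isfinite(v) for v in scaled):
                break
            self.scale = s / 2.0  # overflow: back off and retry
            self._good_steps = 0
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale = s * 2.0  # a run of good steps: try a larger scale next time
            self._good_steps = 0
        return scaled, s

    @staticmethod
    def to_high(scaled: list[float], scale: float) -> list[float]:
        """Undo the scaling in high precision (Python floats)."""
        return [v / scale for v in scaled]


tiny = [1e-8, -3e-9, 2e-7]
print([to_fp16(a) for a in tiny])  # unscaled fp16 flushes or distorts these values
scaler = DynamicAdjointScaler()
stored, used = scaler.to_low(tiny)
print(DynamicAdjointScaler.to_high(stored, used))
```

Returning the scale actually used, rather than reading `self.scale` later, keeps unscaling correct even when the factor grows right after a successful step.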

What would settle it

The claim would be undercut if training a Neural ODE on one of the paper's challenging test cases with the mixed-precision scheme produced noticeably lower accuracy, or diverged, compared with the identical model trained in single precision.

Original abstract

Exploiting low-precision computations has become a standard strategy in deep learning to address the growing computational costs imposed by ever larger models and datasets. However, naively performing all computations in low precision can lead to roundoff errors and instabilities. Therefore, mixed precision training schemes usually store the weights in high precision and use low-precision computations only for whitelisted operations. Despite their success, these principles are currently not reliable for training continuous-time architectures such as neural ordinary differential equations (Neural ODEs). This paper presents a mixed precision training framework for neural ODEs consisting of explicit ODE solvers and a custom backpropagation scheme and shows their effectiveness in a range of learning tasks. Our scheme uses low-precision computations for evaluating the velocity, parameterized by the neural network, and for storing intermediate states, while numerical reliability is provided by custom dynamic adjoint scaling and by accumulating the solution and gradients in higher precision. These contributions address two key challenges in training neural ODEs: the computational cost of repeated network evaluations and the growth of memory requirements with the number of time steps or layers. Along with the paper we publish our extendable, open-source PyTorch package rampde, whose syntax resembles that of leading packages to provide a drop-in replacement in existing codes. We demonstrate the reliability and effectiveness of our scheme using challenging test cases and on neural ODE applications in image classification and generative models, achieving approximately 50% memory reduction and up to 2x speedup while maintaining accuracy comparable to single-precision training.
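The backward-pass ingredients named in the abstract (low-precision stored states, a scaled low-precision adjoint, and high-precision gradient accumulation) can be sketched for the scalar ODE y' = θy integrated with forward Euler. This is a hypothetical discrete-adjoint illustration, not the rampde API or the paper's exact scheme: the scale is fixed rather than dynamic, fp16 is emulated with `struct`, and a small upstream gradient stands in for a loss whose adjoint would otherwise fall into the fp16 subnormal range.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE binary16 (struct raises OverflowError beyond the fp16 range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def forward(theta: float, y0: float, h: float, n: int) -> tuple[float, list[float]]:
    """Forward Euler for y' = theta * y; returns the final state and fp16 checkpoints."""
    y, checkpoints = y0, []
    for _ in range(n):
        checkpoints.append(to_fp16(y))   # state stored in low precision
        v = to_fp16(theta * to_fp16(y))  # velocity evaluated in low precision
        y = y + h * v                    # solution accumulated in high precision
    return y, checkpoints


def backward(theta: float, h: float, checkpoints: list[float],
             dloss_dy: float, scale: float) -> float:
    """Discrete adjoint of the Euler steps above, returning dloss/dtheta."""
    lam_low = to_fp16(dloss_dy * scale)  # scaled adjoint stored in fp16
    grad = 0.0                           # gradient accumulated in high precision
    for y_n in reversed(checkpoints):
        lam = lam_low / scale            # unscale before accumulating
        grad += lam * h * y_n            # d(y_{n+1})/d(theta) = h * y_n
        lam_low = to_fp16(lam_low * (1.0 + h * theta))  # d(y_{n+1})/d(y_n) = 1 + h*theta
    return grad


# Hypothetical toy problem with a small upstream gradient.
theta, y0, h, n, dloss_dy = -1.0, 1.0, 0.01, 100, 1e-6
_, ckpts = forward(theta, y0, h, n)
exact = dloss_dy * n * h * y0 * (1.0 + h * theta) ** (n - 1)  # gradient of the exact Euler map
unscaled = backward(theta, h, ckpts, dloss_dy, scale=1.0)
scaled = backward(theta, h, ckpts, dloss_dy, scale=2.0**16)
print(f"exact {exact:.4e}  unscaled {unscaled:.4e}  scaled {scaled:.4e}")
```

Without scaling, the adjoint sits in the fp16 subnormal range, where multiplying by 1 + hθ = 0.99 rounds back to the same value, so it never decays and the gradient is badly overestimated; with the 2^16 scale the adjoint lives in the normal fp16 range and the gradient tracks the exact value.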

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it: the pith above is the substance, and this is the friction.

Referee Report

2 major / 2 minor

Summary. The manuscript proposes a mixed-precision training framework for Neural ODEs with explicit solvers. Low-precision arithmetic is used for neural-network velocity evaluations and intermediate state storage, while a custom dynamic adjoint scaling is applied in the backward pass and solutions/gradients are accumulated in higher precision to preserve numerical reliability. The authors report approximately 50% memory reduction and up to 2x speedup on image-classification and generative-model tasks while maintaining accuracy comparable to single-precision baselines, and release an open-source PyTorch package (rampde) as a drop-in replacement.

Significance. If the numerical-stability claims are substantiated, the work would meaningfully advance efficient training of continuous-depth models by directly addressing the repeated network evaluations and memory growth that limit Neural ODE scalability. The open-source package is a concrete strength that could accelerate adoption. The empirical results on non-trivial applications lend practical relevance, but the absence of forward-pass error analysis makes the reliability guarantee conditional rather than general.

major comments (2)
  1. [Method (forward-pass description)] Method section describing the forward integration and dynamic adjoint scaling: no error bound, stability analysis, or propagation estimate is supplied for round-off accumulation when low-precision velocities and states are used inside explicit solvers; the scaling is presented only as a backward-pass device, leaving open whether forward perturbations degrade gradients on non-linear or long-horizon dynamics even with high-precision final accumulation.
  2. [Experiments] Experiments section (results on test cases and applications): the reported memory and speedup figures are not accompanied by ablations on precision choices or scaling hyperparameters, nor by verification that post-hoc adjustments preserved the central accuracy claim; this weakens the cross-task generality of the 50% memory / 2x speedup statement.
minor comments (2)
  1. [Method] Notation for the dynamic scaling factor could be introduced earlier and used consistently when contrasting forward and adjoint passes.
  2. [Discussion] A short discussion of how the scheme interacts with common explicit solvers (e.g., RK4 versus Euler) would clarify the scope of the reported speed-ups.
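To make the solver question in minor comment 2 concrete, the sketch below runs a generic explicit Runge-Kutta step in which every stage velocity is evaluated and stored in fp16 while the state update is accumulated in double precision. It is a hypothetical scalar example on y' = -y, not one of the paper's experiments; it counts velocity evaluations (NFE), which RK4 multiplies by four relative to Euler, and emulates fp16 with `struct`.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE binary16 via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Butcher tableaux (A, b, c): A[i] holds the coefficients on the earlier stages.
EULER = ([[]], [1.0], [0.0])
RK4 = ([[], [0.5], [0.0, 0.5], [0.0, 0.0, 1.0]],
       [1 / 6, 1 / 3, 1 / 3, 1 / 6],
       [0.0, 0.5, 0.5, 1.0])


def solve_mixed(f, y0, t0, t1, num_steps, tableau):
    """Explicit RK for a scalar ODE: fp16 stage velocities, double-precision state."""
    a, b, c = tableau
    h = (t1 - t0) / num_steps
    y, t, nfe = y0, t0, 0
    for _ in range(num_steps):
        k = []
        for i in range(len(b)):
            y_stage = y + h * sum(aij * kj for aij, kj in zip(a[i], k))
            k.append(to_fp16(f(t + c[i] * h, to_fp16(y_stage))))  # low-precision stage
            nfe += 1
        y = y + h * sum(bi * ki for bi, ki in zip(b, k))  # high-precision accumulation
        t = t + h
    return y, nfe


def f(t, y):
    return -y  # hypothetical stand-in for the learned velocity field


exact = math.exp(-1.0)
for name, tab in [("Euler", EULER), ("RK4", RK4)]:
    y, nfe = solve_mixed(f, 1.0, 0.0, 1.0, 100, tab)
    print(f"{name}: NFE={nfe}, error={abs(y - exact):.2e}")
```

In this toy run, RK4's error is dominated by the fp16 rounding of the stage velocities rather than by its O(h^4) truncation error, so the three extra evaluations per step improve accuracy only down to that rounding floor.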

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for their constructive comments on our manuscript. We address the major comments point by point below, indicating where revisions will be made to improve the paper.

  1. Referee: [Method (forward-pass description)] Method section describing the forward integration and dynamic adjoint scaling: no error bound, stability analysis, or propagation estimate is supplied for round-off accumulation when low-precision velocities and states are used inside explicit solvers; the scaling is presented only as a backward-pass device, leaving open whether forward perturbations degrade gradients on non-linear or long-horizon dynamics even with high-precision final accumulation.

    Authors: We agree that a formal error analysis for the forward pass would enhance the theoretical grounding of our mixed-precision scheme. The manuscript emphasizes empirical evidence from challenging test cases and applications, where the high-precision accumulation of solutions and gradients helps mitigate round-off errors. The dynamic adjoint scaling is indeed focused on stabilizing the backward pass. To address this, we will include a new subsection discussing error propagation estimates based on standard numerical ODE theory adapted to our low-precision setting, along with additional forward-pass accuracy metrics in the revised manuscript. revision: yes

  2. Referee: [Experiments] Experiments section (results on test cases and applications): the reported memory and speedup figures are not accompanied by ablations on precision choices or scaling hyperparameters, nor by verification that post-hoc adjustments preserved the central accuracy claim; this weakens the cross-task generality of the 50% memory / 2x speedup statement.

    Authors: We acknowledge the value of ablations to support the generality of our results. The original experiments demonstrate performance on image classification and generative models with the proposed mixed-precision approach. We will add ablations varying the precision levels and scaling factors, and include checks confirming that accuracy remains comparable to single-precision baselines across these variations in the revised version. revision: yes
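For reference, the kind of estimate promised in the first response can be illustrated with the textbook perturbation analysis of forward Euler (a standard result, not taken from the paper). Assume f is Lipschitz with constant L, the exact solution satisfies |y''| ≤ M on [0, T], h = T/N, each low-precision velocity evaluation is off by at most δ_f, and each state update incurs a rounding error of at most ε_acc:

```latex
\begin{aligned}
\hat y_{n+1} &= \hat y_n + h\bigl(f(\hat y_n) + \eta_n\bigr) + \rho_n,
  && |\eta_n| \le \delta_f,\ |\rho_n| \le \varepsilon_{\mathrm{acc}},\\
y(t_{n+1}) &= y(t_n) + h\, f\bigl(y(t_n)\bigr) + \tau_n,
  && |\tau_n| \le \tfrac{M}{2} h^2,\\
|e_{n+1}| &\le (1 + hL)\,|e_n| + h\,\delta_f + \varepsilon_{\mathrm{acc}} + \tfrac{M}{2} h^2,
  && e_n := \hat y_n - y(t_n),\ e_0 = 0,\\
|e_N| &\le \frac{e^{LT} - 1}{L}\left(\frac{M h}{2} + \delta_f + \frac{\varepsilon_{\mathrm{acc}}}{h}\right).
\end{aligned}
```

The accumulation term ε_acc/h grows as the step size shrinks, whereas δ_f does not. For h = 10^(-3) and a state of order one, accumulating in fp16 (unit roundoff 2^(-11) ≈ 4.9·10^(-4)) gives ε_acc/h ≈ 0.49, while fp32 accumulation (2^(-24) ≈ 6·10^(-8)) gives about 6·10^(-5). This bounds the forward states only; the adjoint and gradients need an analogous argument.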

Circularity Check

0 steps flagged

No significant circularity; method components are independent engineering choices validated empirically

Full rationale

The paper proposes a mixed-precision framework for Neural ODEs using low-precision velocity evaluations and state storage, with custom dynamic adjoint scaling and high-precision accumulation for stability. These elements are introduced as novel contributions rather than derived from or equivalent to the performance metrics they enable. No self-definitional loops, fitted parameters renamed as predictions, or load-bearing self-citations appear in the described scheme. The effectiveness is shown through experiments on classification and generative tasks, keeping the central claims self-contained against external benchmarks rather than reducing to the paper's own inputs by construction.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axioms · 0 invented entities

The central claim rests on the engineering assumption that selective low-precision arithmetic plus custom scaling suffices for stability; no free parameters or new entities are introduced in the abstract description.

axioms (1)
  • Domain assumption: low-precision computations for velocity evaluation and intermediate states remain numerically stable when paired with dynamic adjoint scaling and high-precision accumulation of solutions and gradients.
    This premise is invoked to justify the mixed-precision scheme for explicit ODE solvers and backpropagation.


