Recognition: no theorem link
Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX
Pith reviewed 2026-05-15 15:03 UTC · model grok-4.3
The pith
Eventax uses differentiable ODE solvers to train spiking networks with exact gradients for any neuron model defined by differential equations.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
We introduce the Eventax framework, which resolves the trade-off between biased surrogate gradients and analytically restricted exact methods by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, the framework uses Diffrax ODE solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API where users specify just the neuron dynamics, spike conditions, and reset rules.
What carries the argument
Differentiable numerical integration of user-defined ODEs with automatic event detection for spikes and resets, which allows automatic differentiation to flow exactly through the solver trajectory.
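The solve-detect-reset loop that carries this argument can be illustrated with a minimal forward simulation. The sketch below uses plain fixed-step Euler integration and threshold checking for a leaky integrate-and-fire neuron; the actual framework replaces this with Diffrax's adaptive solvers and differentiable event root-finding, so treat it as an illustration of the hybrid continuous-discrete pattern, not of Eventax's implementation.

```python
import math

def simulate_lif(i_ext=2.0, tau=1.0, theta=1.0, v_reset=0.0,
                 dt=0.01, t_end=5.0):
    """Forward-Euler LIF simulation with event detection and reset.

    Continuous part: dV/dt = (i_ext - V) / tau.
    Discrete part:   when the event condition g(V) = V - theta crosses
    zero, record a spike and apply the reset rule V -> v_reset.
    """
    v, t = 0.0, 0.0
    spike_times = []
    while t < t_end:
        v += dt * (i_ext - v) / tau   # one step of the continuous dynamics
        t += dt
        if v >= theta:                # event condition reached
            spike_times.append(t)
            v = v_reset               # reset rule
    return spike_times

spikes = simulate_lif()
# With constant input the closed-form inter-spike interval is
# tau * ln(i_ext / (i_ext - theta)) = ln 2 ~ 0.693, which this
# fixed-step simulation approximates to within the step size.
```

A differentiable implementation differs from this sketch in exactly the place the paper emphasizes: the event time is located by a root-finder inside the solver rather than at step granularity, which is what lets autodiff flow through it.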
If this is right
- Any neuron expressible as an ODE can be dropped into a network and trained end-to-end with exact gradients.
- Complex models such as multi-compartment neurons with dendritic spikes become directly trainable without surrogate approximations.
- Both time-to-first-spike and continuous state-based loss functions work uniformly across neuron types.
- New neuron dynamics require only the ODE, threshold, and reset definitions; no new solver code is needed.
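The last bullet can be made concrete with a hedged sketch of what an ODE-plus-threshold-plus-reset specification could look like. The names here (`NeuronSpec`, `drift`, `spike`, `reset`) are hypothetical illustrations of the described API shape, not Eventax's actual interface:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class NeuronSpec:
    """Hypothetical container for the three ingredients the paper says a
    user must supply: dynamics, spike condition, and reset rule."""
    drift: Callable[[float, float], float]  # dV/dt = drift(V, I)
    spike: Callable[[float], bool]          # event condition on the state
    reset: Callable[[float], float]         # post-event state map

# Leaky integrate-and-fire: linear relaxation toward the input.
lif = NeuronSpec(
    drift=lambda v, i: (i - v) / 1.0,
    spike=lambda v: v >= 1.0,
    reset=lambda v: 0.0,
)

# Quadratic integrate-and-fire: only the drift changes; the threshold and
# reset definitions are reused untouched, as the bullet above claims.
qif = NeuronSpec(
    drift=lambda v, i: v * v + i,
    spike=lambda v: v >= 1.0,
    reset=lambda v: 0.0,
)
```

Under this shape, swapping neuron models is a data change, not a solver change, which is the claim being tested.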
Where Pith is reading between the lines
- The same solver-plus-event pattern could be applied to other hybrid continuous-discrete systems such as hybrid automata or physical simulators.
- Because the method inherits JAX's composability, it can be combined with existing JAX libraries for optimization, uncertainty quantification, or hardware mapping.

- Exact gradients remove one source of bias, so performance gaps between surrogate and exact training can now be measured cleanly on richer neuron models.
Load-bearing premise
Numerical event detection and state resets inside the ODE solver preserve exact gradient flow without introducing instabilities or extra discretization error.
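The standard justification for this premise is the implicit-function-theorem treatment of event times; a sketch, under the assumption (not stated in the paper) that the event condition g is smooth and the crossing is transversal:

```latex
g\bigl(y(t^{*};\theta)\bigr) = 0
\;\Longrightarrow\;
\nabla_{y} g \cdot \Bigl( f\bigl(y(t^{*})\bigr)\,\frac{dt^{*}}{d\theta}
  + \frac{\partial y}{\partial \theta} \Bigr) = 0
\;\Longrightarrow\;
\frac{dt^{*}}{d\theta}
  = -\,\frac{\nabla_{y} g \cdot \partial y/\partial\theta}
            {\nabla_{y} g \cdot f\bigl(y(t^{*})\bigr)}.
```

The denominator is the rate at which the trajectory crosses the event surface; the premise can fail precisely where this derivation does, i.e. when the root-finder mislocates t* or the crossing is nearly tangential (denominator near zero).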
What would settle it
Compare the gradients computed by Eventax against finite-difference perturbations or analytical derivatives on a multi-compartment model; mismatch or training divergence would falsify exactness.
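For the simplest case this comparison can even be done against a closed form. Assuming a LIF neuron tau dV/dt = I - V with V(0) = 0 and constant input I > theta, the first spike time is t* = tau ln(I / (I - theta)), so an exact-gradient implementation must reproduce dt*/dI; a minimal sketch of the finite-difference side of the check (illustrative, not from the paper):

```python
import math

def spike_time(i_ext, tau=1.0, theta=1.0):
    """Closed-form first-spike time of a LIF neuron with V(0) = 0 and
    constant input i_ext > theta: V(t) = i_ext * (1 - exp(-t / tau))."""
    return tau * math.log(i_ext / (i_ext - theta))

def spike_time_grad(i_ext, tau=1.0, theta=1.0):
    """Analytical d t*/d i_ext, the quantity an exact-gradient framework
    must reproduce through its solver and event handling."""
    return -tau * theta / (i_ext * (i_ext - theta))

# Central finite differences as the independent check proposed above.
h = 1e-6
i_ext = 2.0
fd = (spike_time(i_ext + h) - spike_time(i_ext - h)) / (2 * h)
assert abs(fd - spike_time_grad(i_ext)) < 1e-8
```

Disagreement between the framework's gradient and either the analytic value or the finite-difference estimate, beyond solver tolerance, would falsify the exactness claim.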
read the original abstract
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as Leaky Integrate and Fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our frame-work uses Diffrax ODE-solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as Leaky Integrate-and-fire (LIF), Quadratic Integrate-and-fire (QIF), Exponential integrate-and-fire (EIF), Izhikevich and Event-based Gated Recurrent Unit (EGRU) with both time-to-first-spike and state-based loss functions, demonstrating its utility for prototyping and testing event-based architectures trained with exact gradients. We also demonstrate the application of this framework for more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical Pyramidal neurons for computation. Code available at https://github.com/efficient-scalable-machine-learning/eventax.
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The manuscript introduces Eventax, a JAX framework that employs Diffrax differentiable ODE solvers to train event-based neural networks with gradients claimed to be exact with respect to the forward simulation. Users define neuron dynamics, spike conditions, and reset rules for arbitrary ODE-based models; the approach is demonstrated on LIF, QIF, EIF, Izhikevich, and EGRU neurons using Yin-Yang and MNIST benchmarks with time-to-first-spike and state-based losses, plus a multi-compartment pyramidal neuron incorporating dendritic spikes.
Significance. If the exactness claim holds without material numerical artifacts from event detection or resets, the work would meaningfully advance SNN training by removing the surrogate-gradient bias versus analytical-spike-time trade-off and enabling flexible use of complex biophysical models. The simple API, reliance on established libraries (JAX/Diffrax), and open code are practical strengths that support prototyping and reproducibility.
major comments (2)
- [Abstract] Abstract and Methods: The central claim that gradients are 'exact with respect to the forward simulation for any neuron model defined by ODEs' is not supported by any reported verification (finite-difference checks, adjoint-error bounds, or tolerance-sensitivity analysis). This verification is load-bearing because numerical root-finding for events and instantaneous resets can introduce non-differentiable artifacts whose effect on the computed VJP is not quantified.
- [Results] Multi-compartment demonstration: For the human L2/3 pyramidal neuron with dendritic spikes, the manuscript provides no derivation or empirical test showing that Diffrax's event handling and state resets preserve gradient accuracy through the discontinuities; the skeptic concern about accumulated discretization error in non-smooth multi-compartment dynamics therefore remains unaddressed.
minor comments (1)
- [Abstract] Abstract contains the typographical error 'frame-work' (should be 'framework').
Simulated Author's Rebuttal
We thank the referee for their detailed and constructive comments. We address the major comments point-by-point below and will revise the manuscript to include the requested verifications of gradient exactness.
read point-by-point responses
-
Referee: [Abstract] Abstract and Methods: The central claim that gradients are 'exact with respect to the forward simulation for any neuron model defined by ODEs' is not supported by any reported verification (finite-difference checks, adjoint-error bounds, or tolerance-sensitivity analysis). This verification is load-bearing because numerical root-finding for events and instantaneous resets can introduce non-differentiable artifacts whose effect on the computed VJP is not quantified.
Authors: We agree that explicit verification is required to substantiate the exactness claim. In the revised manuscript we will add a new Methods subsection reporting finite-difference checks: for each demonstrated neuron model we will compare the VJP returned by JAX autodiff through Diffrax against central finite differences over small perturbations in parameters and inputs. We will also include a tolerance-sensitivity study that quantifies any residual discrepancy attributable to event root-finding or instantaneous resets. These additions will directly address the concern about potential non-differentiable artifacts. revision: yes
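The verification the authors promise could take the shape of a small central-difference checker; a sketch of the procedure (function names are illustrative, not from the manuscript), demonstrated on a smooth stand-in function since running it on an actual Eventax model requires the framework itself:

```python
import math

def central_diff_check(f, grad_f, xs, h=1e-5):
    """Return the maximum relative error between an analytical gradient
    and a central finite-difference estimate over the probe points xs."""
    worst = 0.0
    for x in xs:
        fd = (f(x + h) - f(x - h)) / (2 * h)
        g = grad_f(x)
        err = abs(fd - g) / max(abs(g), 1e-12)
        worst = max(worst, err)
    return worst

# Demonstration on a smooth stand-in; in the promised experiments f would
# be a scalar loss through the solver and grad_f the autodiff VJP, with h
# and the solver tolerance swept jointly for the sensitivity study.
err = central_diff_check(math.sin, math.cos, [0.1, 0.7, 1.3])
assert err < 1e-8
```

For event-crossing perturbations the probe points would be chosen to straddle a spike, where any non-differentiable artifact from root-finding or resets would show up as a step in the finite-difference estimate.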
-
Referee: [Results] Multi-compartment demonstration: For the human L2/3 pyramidal neuron with dendritic spikes, the manuscript provides no derivation or empirical test showing that Diffrax's event handling and state resets preserve gradient accuracy through the discontinuities; the skeptic concern about accumulated discretization error in non-smooth multi-compartment dynamics therefore remains unaddressed.
Authors: We acknowledge the absence of targeted validation for the multi-compartment case. In the revision we will add an empirical test in the Results section that compares Eventax-computed gradients for the human L2/3 pyramidal neuron against finite-difference approximations, with particular attention to perturbations crossing dendritic-spike events. We will also include a concise discussion of Diffrax's differentiable event-detection and reset mechanisms (based on its adjoint sensitivity implementation) and how they limit accumulation of discretization error across discontinuities. This will provide the requested empirical support. revision: yes
Circularity Check
No circularity: Eventax relies on external Diffrax/JAX solvers without self-referential reduction
full rationale
The paper's derivation chain consists of wrapping standard differentiable ODE integration (via the external Diffrax library) around user-specified neuron dynamics, spike conditions, and reset rules. No step reduces a claimed result to a fitted parameter, self-definition, or load-bearing self-citation; the exact-gradient property is asserted as a direct consequence of the library's adjoint/VJP implementation rather than being derived from the paper's own equations or prior author work. Benchmarks on LIF/QIF/EIF/Izhikevich/EGRU and multi-compartment models serve as empirical validation, not as inputs that are renamed as outputs. The framework's claims therefore rest on external numerical libraries rather than on self-referential derivations.
Axiom & Free-Parameter Ledger
axioms (1)
- domain assumption Differentiable numerical ODE solvers (Diffrax) can accurately integrate neuron state equations and detect spike events while preserving exact gradient flow through the simulation.
Reference graph
Works this paper leans on
- [1] J. Lohoff, J. Finkbeiner, and E. Neftci, "Snnax: spiking neural networks in JAX," in 2024 International Conference on Neuromorphic Systems (ICONS). IEEE, 2024, pp. 251–255.
- [2] J. K. Eshraghian, M. Ward, E. Neftci, X. Wang, G. Lenz, G. Dwivedi, M. Bennamoun, D. S. Jeong, and W. D. Lu, "Training spiking neural networks using lessons from deep learning," Proceedings of the IEEE, vol. 111, no. 9, pp. 1016–1054, 2023.
- [3] G. Bellec, D. Salaj, A. Subramoney, R. Legenstein, and W. Maass, "Long short-term memory and learning-to-learn in networks of spiking neurons," in Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, Eds. Curran Associates, Inc., 2018, pp. 787–797.
- [4] T. C. Wunderlich and C. Pehle, "Event-based backpropagation can compute exact gradients for spiking neural networks," Scientific Reports, vol. 11, no. 1, Jun. 2021. [Online]. Available: http://dx.doi.org/10.1038/s41598-021-91786-z
- [5] C. Klos and R.-M. Memmesheimer, "Smooth exact gradient descent learning in spiking neural networks," Physical Review Letters, vol. 134, no. 2, Jan. 2025. [Online]. Available: http://dx.doi.org/10.1103/PhysRevLett.134.027301
- [6] P. Kidger, "On Neural Differential Equations," Ph.D. dissertation, University of Oxford, 2021.
- [7] C. Holberg and C. Salvi, "Exact gradients for stochastic spiking neural networks driven by rough signals," Advances in Neural Information Processing Systems, vol. 37, pp. 31907–31939, 2024.
- [8] C. Pehle and J. E. Pedersen, "Norse - A deep learning library for spiking neural networks," Jan. 2021, documentation: https://norse.ai/docs/. [Online]. Available: https://doi.org/10.5281/zenodo.4422025
- [9] E. Müller, M. Althaus, E. Arnold, P. Spilger, C. Pehle, and J. Schemmel, "jaxsnn: Event-driven gradient estimation for analog neuromorphic hardware," in 2024 Neuro Inspired Computational Elements Conference (NICE). IEEE, 2024, pp. 1–6.
- [10] T. Shoesmith, J. C. Knight, B. Mészáros, J. Timcheck, and T. Nowotny, "Eventprop training for efficient neuromorphic applications," 2025. [Online]. Available: https://arxiv.org/abs/2503.04341
- [11] J. Göltz, L. Kriener, A. Baumbach, S. Billaudelle, O. Breitwieser, B. Cramer, D. Dold, A. F. Kungl, W. Senn, J. Schemmel, K. Meier, and M. A. Petrovici, "Fast and energy-efficient neuromorphic deep learning with first-spike times," Nature Machine Intelligence, vol. 3, no. 9, pp. 823–835, Sep. 2021. [Online]. Available: http://dx.doi.org/10.1038/s42256-021-00388-x
- [12] N. Fourcaud-Trocmé, D. Hansel, C. van Vreeswijk, and N. Brunel, "How spike generation mechanisms determine the neuronal response to fluctuating inputs," Journal of Neuroscience, vol. 23, no. 37, pp. 11628–11640, 2003.
- [13] E. Izhikevich, "Simple model of spiking neurons," IEEE Transactions on Neural Networks, vol. 14, no. 6, pp. 1569–1572, 2003.
- [14] A. Subramoney, K. K. Nazeer, M. Schöne, C. Mayr, and D. Kappel, "Efficient recurrent architectures through activity sparsity and sparse back-propagation through time," in The Eleventh International Conference on Learning Representations, 2023.
- [15] A. Gidon, T. A. Zolnik, P. Fidzinski, F. Bolduan, A. Papoutsi, P. Poirazi, M. Holtkamp, I. Vida, and M. E. Larkum, "Dendritic action potentials and computation in human layer 2/3 cortical neurons," Science, vol. 367, no. 6473, pp. 83–87, Jan. 2020.
- [16] C. Morris and H. Lecar, "Voltage oscillations in the barnacle giant muscle fiber," Biophysical Journal, vol. 35, no. 1, pp. 193–213, Jul. 1981.
- [17] L. Kriener, J. Göltz, and M. A. Petrovici, "The yin-yang dataset," in Proceedings of the 2022 Annual Neuro-Inspired Computational Elements Conference, 2022, pp. 107–111.
- [18] Y. LeCun and C. Cortes, "MNIST handwritten digit database." [Online]. Available: http://yann.lecun.com/exdb/mnist/
- [20] T. Nowotny, J. P. Turner, and J. C. Knight, "Loss shaping enhances exact gradient learning with eventprop in spiking neural networks," Neuromorphic Computing and Engineering, vol. 5, no. 1, p. 014001, Jan. 2025. [Online]. Available: http://dx.doi.org/10.1088/2634-4386/ada852
discussion (0)