pith. machine review for the scientific record.

arxiv: 2603.08146 · v3 · submitted 2026-03-09 · 💻 cs.LG

Recognition: no theorem link

Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX

Authors on Pith no claims yet

Pith reviewed 2026-05-15 15:03 UTC · model grok-4.3

classification 💻 cs.LG
keywords spiking neural networks · event-based networks · differentiable ODE solvers · exact gradients · JAX · neuron models · surrogate gradient trade-off · multi-compartment neurons

The pith

Eventax uses differentiable ODE solvers to train spiking networks with exact gradients for any neuron model defined by differential equations.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper argues that prior methods forced a choice: flexible neuron models trained with biased surrogate gradients, or exact gradients available only for simple analytical neuron types. Eventax removes this trade-off by wrapping Diffrax numerical ODE solvers with automatic event detection for spikes and state resets inside JAX. A user supplies only the ODE dynamics, spike threshold condition, and reset rule; the framework then integrates the system forward and back-propagates exact gradients through the solver steps. Demonstrations include LIF, QIF, EIF, Izhikevich, EGRU, and a multi-compartment pyramidal-cell model on Yin-Yang and MNIST tasks with both time-to-first-spike and state-based losses.
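
To make that mechanism concrete, here is a minimal sketch of the underlying pattern, not Eventax's own API: a user-defined LIF ODE is integrated by a Diffrax solver, a root-found event stops the solve at the spike threshold, and JAX autodiff returns the gradient of the spike time. It assumes Diffrax's Event interface (diffrax ≥ 0.5) together with an Optimistix root finder; all names below are illustrative.

    import jax
    import diffrax
    import optimistix as optx

    def lif_dynamics(t, v, args):
        tau, i_in = args                 # membrane time constant, constant drive
        return (-v + i_in) / tau         # dv/dt for a leaky integrate-and-fire neuron

    def spike_condition(t, y, args, **kwargs):
        return y - 1.0                   # root crossing at the firing threshold v = 1

    def first_spike_time(i_in):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(lif_dynamics),
            diffrax.Tsit5(),
            t0=0.0, t1=10.0, dt0=0.01, y0=0.0,
            args=(2.0, i_in),
            # The root finder locates the threshold crossing to solver precision,
            # so the event time is a differentiable function of the parameters.
            event=diffrax.Event(spike_condition, optx.Newton(rtol=1e-8, atol=1e-8)),
        )
        return sol.ts[-1]                # with the default SaveAt, the last saved time is where the solve stopped

    # Exact gradient of the spike time w.r.t. the drive, straight through the solver.
    print(jax.grad(first_spike_time)(1.5))

Because the event time comes out of a differentiable root find inside the solve, no surrogate derivative is needed at the threshold crossing.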

Core claim

We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our framework uses Diffrax ODE-solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules.

What carries the argument

Differentiable numerical integration of user-defined ODEs with automatic event detection for spikes and resets, which allows automatic differentiation to flow exactly through the solver trajectory.
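
The reset half of that pattern could look like the following, again a hypothetical sketch rather than Eventax's implementation, continuing the code under "The pith" (it reuses lif_dynamics and spike_condition): each smooth segment is solved up to the next detected spike, a discontinuous reset maps the state across the event, and integration resumes, with autodiff flowing through every segment and event time.

    import jax
    import jax.numpy as jnp
    import diffrax
    import optimistix as optx

    def reset(v):
        return jnp.zeros_like(v)             # hard reset after each spike

    def nth_spike_time(i_in, n_spikes=3, t_end=50.0):
        t0, v0 = 0.0, 0.0
        # Unrolled loop: assumes at least n_spikes events occur before t_end.
        for _ in range(n_spikes):
            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(lif_dynamics), diffrax.Tsit5(),
                t0=t0, t1=t_end, dt0=0.01, y0=v0, args=(2.0, i_in),
                event=diffrax.Event(spike_condition, optx.Newton(rtol=1e-8, atol=1e-8)),
            )
            t0 = sol.ts[-1]                  # event (spike) time closes this segment
            v0 = reset(sol.ys[-1])           # discontinuous jump applied between segments
        return t0

    # Gradient of the third spike time w.r.t. the drive, exact through every
    # segment and every detected event.
    print(jax.grad(nth_spike_time)(1.5))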

If this is right

  • Any neuron expressible as an ODE can be dropped into a network and trained end-to-end with exact gradients.
  • Complex models such as multi-compartment neurons with dendritic spikes become directly trainable without surrogate approximations.
  • Both time-to-first-spike and continuous state-based loss functions work uniformly across neuron types.
  • New neuron dynamics require only the ODE, threshold, and reset definitions; no new solver code is needed (see the sketch after this list).
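
As a hypothetical illustration of that last point, shaped after the five pieces Figure 1 attributes to the NeuronModel interface (initial state, dynamics, spike condition, input-spike handling, post-spike reset), a quadratic integrate-and-fire neuron might be declared roughly as follows. The class and method names are illustrative, not Eventax's actual API.

    import jax.numpy as jnp

    class QIFNeuron:
        """Quadratic integrate-and-fire: dv/dt = v**2 + I."""

        def initial_state(self, n):
            return jnp.full((n,), -1.0)           # resting potential for n neurons

        def dynamics(self, t, v, i_in):
            return v**2 + i_in                    # the only differential equation the user writes

        def spike_condition(self, t, v):
            return v - 10.0                       # root crossing at the spike cutoff

        def on_input_spike(self, v, w):
            return v + w                          # presynaptic spike adds the synaptic weight

        def reset(self, v):
            return jnp.where(v >= 10.0, -1.0, v)  # send spiking units back to rest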

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the authors make directly.

  • The same solver-plus-event pattern could be applied to other hybrid continuous-discrete systems such as hybrid automata or physical simulators (see the sketch after this list).
  • Because the method inherits JAX's composability, it can be combined with existing JAX libraries for optimization, uncertainty, or hardware mapping.
  • Exact gradients remove one source of bias, so performance gaps between surrogate and exact training can now be measured cleanly on richer neuron models.
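
A speculative illustration of that first extension, not drawn from the paper: the same event-plus-root-finder pattern applied to a bouncing ball, where the impact replaces the spike and the analytic answer d(t_impact)/d(h0) = 1/(g · t_impact) makes the autodiff result easy to check. It assumes the same Diffrax Event API as the sketches above.

    import jax
    import jax.numpy as jnp
    import diffrax
    import optimistix as optx

    def free_fall(t, state, args):
        h, v = state
        return jnp.array([v, -9.81])             # dh/dt = v, dv/dt = -g

    def hits_ground(t, state, args, **kwargs):
        return state[0]                          # event when the height reaches 0

    def first_impact_time(h0):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(free_fall), diffrax.Tsit5(),
            t0=0.0, t1=10.0, dt0=0.01, y0=jnp.array([h0, 0.0]),
            event=diffrax.Event(hits_ground, optx.Newton(rtol=1e-8, atol=1e-8)),
        )
        return sol.ts[-1]

    # Autodiff through the solver recovers 1 / (g * t_impact).
    print(jax.grad(first_impact_time)(2.0))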

Load-bearing premise

Numerical event detection and state resets inside the ODE solver preserve exact gradient flow without introducing instabilities or extra discretization error.

What would settle it

Compare the gradients computed by Eventax against finite-difference perturbations or analytical derivatives on a multi-compartment model; mismatch or training divergence would falsify exactness.
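
A minimal sketch of that falsification test, assuming any scalar spike time or loss computed through the framework (for instance first_spike_time from the sketch under "The pith"):

    import jax

    def gradient_check(f, x, eps=1e-4):
        """Compare autodiff through the solver with a central finite difference."""
        autodiff = jax.grad(f)(x)
        numeric = (f(x + eps) - f(x - eps)) / (2.0 * eps)
        return autodiff, numeric, abs(autodiff - numeric)

    # e.g. gradient_check(first_spike_time, 1.5)
    # Exactness (up to solver tolerance) predicts agreement to several digits;
    # a persistent mismatch on a multi-compartment model would falsify the claim.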

Figures

Figures reproduced from arXiv: 2603.08146 by Anand Subramoney, David Kappel, Lukas König, Manuel Kuhn.

Figure 1
Figure 1: The NeuronModel interface. Users define custom neuron models by implementing: initial state, dynamics, spike condition, input spike handling, and post-spike reset. event-based recurrent units that are machine-learning oriented such as the EGRU, while remaining straightforward to extend with custom models. The MultiNeuronModel class enables heterogeneous networks out of previously defined NeuronModel in w… view at source ↗
Figure 2
Figure 2: Schematic of the multi-compartment neuron model and [PITH_FULL_IMAGE:figures/full_fig_p003_2.png] view at source ↗
Figure 4
Figure 4: Usage example of EventPropJax: training a LIF network [PITH_FULL_IMAGE:figures/full_fig_p004_4.png] view at source ↗
Figure 5
Figure 5: This construction implements an XOR task over the two temporally separated inputs: trials where the two input populations differ are labelled as class 1, and trials where they are equal as class 0. For this task, we generate a set of 5000 training samples as well as a test and validation set containing 1000 samples each. The recurrent layer consists of 32 fully connected neurons without self-connection… view at source ↗
Figure 6
Figure 6: Throughput of QIF models on the TTFS Yin–Yang task (NVIDIA H100). Three models were trained for various Euler step-sizes [PITH_FULL_IMAGE:figures/full_fig_p008_6.png] view at source ↗
read the original abstract

Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as Leaky Integrate and Fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our frame-work uses Diffrax ODE-solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs . It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as Leaky Integrate-and-fire (LIF), Quadratic Integrate-and-fire (QIF), Exponential integrate-and-fire (EIF), Izhikevich and Event-based Gated Recurrent Unit (EGRU) with both time-to-first-spike and state-based loss functions, demonstrating its utility for prototyping and testing event-based architectures trained with exact gradients. We also demonstrate the application of this framework for more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical Pyramidal neurons for computation. Code available at https://github.com/efficient-scalable-machine-learning/eventax.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

2 major / 1 minor

Summary. The manuscript introduces Eventax, a JAX framework that employs Diffrax differentiable ODE solvers to train event-based neural networks with gradients claimed to be exact with respect to the forward simulation. Users define neuron dynamics, spike conditions, and reset rules for arbitrary ODE-based models; the approach is demonstrated on LIF, QIF, EIF, Izhikevich, and EGRU neurons using Yin-Yang and MNIST benchmarks with time-to-first-spike and state-based losses, plus a multi-compartment pyramidal neuron incorporating dendritic spikes.

Significance. If the exactness claim holds without material numerical artifacts from event detection or resets, the work would meaningfully advance SNN training by removing the surrogate-gradient bias versus analytical-spike-time trade-off and enabling flexible use of complex biophysical models. The simple API, reliance on established libraries (JAX/Diffrax), and open code are practical strengths that support prototyping and reproducibility.

major comments (2)
  1. [Abstract] Abstract and Methods: The central claim that gradients are 'exact with respect to the forward simulation for any neuron model defined by ODEs' is not supported by any reported verification (finite-difference checks, adjoint-error bounds, or tolerance-sensitivity analysis). This verification is load-bearing because numerical root-finding for events and instantaneous resets can introduce non-differentiable artifacts whose effect on the computed VJP is not quantified.
  2. [Results] Multi-compartment demonstration: For the human L2/3 pyramidal neuron with dendritic spikes, the manuscript provides no derivation or empirical test showing that Diffrax's event handling and state resets preserve gradient accuracy through the discontinuities; the skeptic concern about accumulated discretization error in non-smooth multi-compartment dynamics therefore remains unaddressed.
minor comments (1)
  1. [Abstract] Abstract contains the typographical error 'frame-work' (should be 'framework').

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for their detailed and constructive comments. We address the major comments point-by-point below and will revise the manuscript to include the requested verifications of gradient exactness.

read point-by-point responses
  1. Referee: [Abstract] Abstract and Methods: The central claim that gradients are 'exact with respect to the forward simulation for any neuron model defined by ODEs' is not supported by any reported verification (finite-difference checks, adjoint-error bounds, or tolerance-sensitivity analysis). This verification is load-bearing because numerical root-finding for events and instantaneous resets can introduce non-differentiable artifacts whose effect on the computed VJP is not quantified.

    Authors: We agree that explicit verification is required to substantiate the exactness claim. In the revised manuscript we will add a new Methods subsection reporting finite-difference checks: for each demonstrated neuron model we will compare the VJP returned by JAX autodiff through Diffrax against central finite differences over small perturbations in parameters and inputs. We will also include a tolerance-sensitivity study that quantifies any residual discrepancy attributable to event root-finding or instantaneous resets. These additions will directly address the concern about potential non-differentiable artifacts. revision: yes

  2. Referee: [Results] Multi-compartment demonstration: For the human L2/3 pyramidal neuron with dendritic spikes, the manuscript provides no derivation or empirical test showing that Diffrax's event handling and state resets preserve gradient accuracy through the discontinuities; the skeptic concern about accumulated discretization error in non-smooth multi-compartment dynamics therefore remains unaddressed.

    Authors: We acknowledge the absence of targeted validation for the multi-compartment case. In the revision we will add an empirical test in the Results section that compares Eventax-computed gradients for the human L2/3 pyramidal neuron against finite-difference approximations, with particular attention to perturbations crossing dendritic-spike events. We will also include a concise discussion of Diffrax's differentiable event-detection and reset mechanisms (based on its adjoint sensitivity implementation) and how they limit accumulation of discretization error across discontinuities. This will provide the requested empirical support. revision: yes

Circularity Check

0 steps flagged

No circularity: Eventax relies on external Diffrax/JAX solvers without self-referential reduction

full rationale

The paper's derivation chain consists of wrapping standard differentiable ODE integration (via the external Diffrax library) around user-specified neuron dynamics, spike conditions, and reset rules. No step reduces a claimed result to a fitted parameter, self-definition, or load-bearing self-citation; the exact-gradient property is asserted as a direct consequence of the library's adjoint/VJP implementation rather than being derived from the paper's own equations or prior author work. Benchmarks on LIF/QIF/EIF/Izhikevich/EGRU and multi-compartment models serve as empirical validation, not as inputs that are renamed as outputs. The framework's claims therefore rest on external numerical libraries rather than on any self-referential step.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axiom · 0 invented entities

The central claim rests on the assumption that standard numerical ODE integration and event detection produce gradients that are exact with respect to the discrete-event forward pass; no free parameters or new entities are introduced beyond the choice of neuron ODEs and loss functions.

axioms (1)
  • domain assumption Differentiable numerical ODE solvers (Diffrax) can accurately integrate neuron state equations and detect spike events while preserving exact gradient flow through the simulation.
    Invoked when stating that gradients are exact with respect to the forward simulation for any ODE-defined neuron model.

pith-pipeline@v0.9.0 · 5606 in / 1197 out tokens · 54002 ms · 2026-05-15T15:03:35.189351+00:00 · methodology

discussion (0)


Reference graph

Works this paper leans on

20 extracted references · 20 canonical work pages

  1. [1]

    Snnax-spiking neural networks in jax,

    J. Lohoff, J. Finkbeiner, and E. Neftci, “Snnax-spiking neural networks in jax,” in 2024 International Conference on Neuromorphic Systems (ICONS). IEEE, 2024, pp. 251–255

  2. [2]

    Training spiking neural networks using lessons from deep learning,

    J. K. Eshraghian, M. Ward, E. Neftci, X. Wang, G. Lenz, G. Dwivedi, M. Bennamoun, D. S. Jeong, and W. D. Lu, “Training spiking neural networks using lessons from deep learning,” Proceedings of the IEEE, vol. 111, no. 9, pp. 1016–1054, 2023

  3. [3]

    Long short-term memory and Learning-to-learn in networks of spiking neurons,

    G. Bellec, D. Salaj, A. Subramoney, R. Legenstein, and W. Maass, “Long short-term memory and Learning-to-learn in networks of spiking neurons,” in Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, Eds. Curran Associates, Inc., 2018, pp. 787–797

  4. [4]

    Event-based backpropagation can compute exact gradients for spiking neural networks,

    T. C. Wunderlich and C. Pehle, “Event-based backpropagation can compute exact gradients for spiking neural networks,” Scientific Reports, vol. 11, no. 1, Jun. 2021. [Online]. Available: http://dx.doi.org/10.1038/s41598-021-91786-z

  5. [5]

    Smooth exact gradient descent learning in spiking neural networks,

    C. Klos and R.-M. Memmesheimer, “Smooth exact gradient descent learning in spiking neural networks,” Physical Review Letters, vol. 134, no. 2, Jan. 2025. [Online]. Available: http://dx.doi.org/10.1103/PhysRevLett.134.027301

  6. [6]

    On Neural Differential Equations,

    P. Kidger, “On Neural Differential Equations,” Ph.D. dissertation, University of Oxford, 2021

  7. [7]

    Exact gradients for stochastic spiking neural networks driven by rough signals,

    C. Holberg and C. Salvi, “Exact gradients for stochastic spiking neural networks driven by rough signals,” Advances in Neural Information Processing Systems, vol. 37, pp. 31907–31939, 2024

  8. [8]

    Norse - A deep learning library for spiking neural networks,

    C. Pehle and J. E. Pedersen, “Norse - A deep learning library for spiking neural networks,” Jan. 2021, documentation: https://norse.ai/docs/. [Online]. Available: https://doi.org/10.5281/zenodo.4422025

  9. [9]

    jaxsnn: Event-driven gradient estimation for analog neuromorphic hardware,

    E. Müller, M. Althaus, E. Arnold, P. Spilger, C. Pehle, and J. Schemmel, “jaxsnn: Event-driven gradient estimation for analog neuromorphic hardware,” in 2024 Neuro-Inspired Computational Elements Conference (NICE). IEEE, 2024, pp. 1–6

  10. [10]

    Eventprop training for efficient neuromorphic applications,

    T. Shoesmith, J. C. Knight, B. Mészáros, J. Timcheck, and T. Nowotny, “Eventprop training for efficient neuromorphic applications,” 2025. [Online]. Available: https://arxiv.org/abs/2503.04341

  11. [11]

    Fast and energy-efficient neuromorphic deep learning with first-spike times,

    J. Göltz, L. Kriener, A. Baumbach, S. Billaudelle, O. Breitwieser, B. Cramer, D. Dold, A. F. Kungl, W. Senn, J. Schemmel, K. Meier, and M. A. Petrovici, “Fast and energy-efficient neuromorphic deep learning with first-spike times,” Nature Machine Intelligence, vol. 3, no. 9, pp. 823–835, Sep. 2021. [Online]. Available: http://dx.doi.org/10.1038/s42256-021-00388-x

  12. [12]

    How spike generation mechanisms determine the neuronal response to fluctuating inputs,

    N. Fourcaud-Trocmé, D. Hansel, C. van Vreeswijk, and N. Brunel, “How spike generation mechanisms determine the neuronal response to fluctuating inputs,” Journal of Neuroscience, vol. 23, no. 37, pp. 11628–11640, 2003

  13. [13]

    Simple model of spiking neurons,

    E. Izhikevich, “Simple model of spiking neurons,” IEEE Transactions on Neural Networks, vol. 14, no. 6, pp. 1569–1572, 2003

  14. [14]

    Efficient recurrent architectures through activity sparsity and sparse back-propagation through time,

    A. Subramoney, K. K. Nazeer, M. Schöne, C. Mayr, and D. Kappel, “Efficient recurrent architectures through activity sparsity and sparse back-propagation through time,” in The Eleventh International Conference on Learning Representations, 2023

  15. [15]

    Dendritic action potentials and computation in human layer 2/3 cortical neurons,

    A. Gidon, T. A. Zolnik, P. Fidzinski, F. Bolduan, A. Papoutsi, P. Poirazi, M. Holtkamp, I. Vida, and M. E. Larkum, “Dendritic action potentials and computation in human layer 2/3 cortical neurons,” Science, vol. 367, no. 6473, pp. 83–87, Jan. 2020

  16. [16]

    Voltage oscillations in the barnacle giant muscle fiber,

    C. Morris and H. Lecar, “Voltage oscillations in the barnacle giant muscle fiber,” Biophysical Journal, vol. 35, no. 1, pp. 193–213, Jul. 1981

  17. [17]

    The yin-yang dataset,

    L. Kriener, J. Göltz, and M. A. Petrovici, “The yin-yang dataset,” in Proceedings of the 2022 Annual Neuro-Inspired Computational Elements Conference, 2022, pp. 107–111

  18. [18]

    MNIST handwritten digit database,

    Y. LeCun and C. Cortes, “MNIST handwritten digit database,”

  19. [19]

    [Online]. Available: http://yann.lecun.com/exdb/mnist/

  20. [20]

    Loss shaping enhances exact gradient learning with eventprop in spiking neural networks,

    T. Nowotny, J. P. Turner, and J. C. Knight, “Loss shaping enhances exact gradient learning with eventprop in spiking neural networks,” Neuromorphic Computing and Engineering, vol. 5, no. 1, p. 014001, Jan. 2025. [Online]. Available: http://dx.doi.org/10.1088/2634-4386/ada852