arxiv: 2111.00254 · v1 · pith:6HLLENLNnew · submitted 2021-10-30 · 💻 cs.LG · cs.PL

Equinox: neural networks in JAX via callable PyTrees and filtered transformations

classification 💻 cs.LG cs.PL
keywords equinox, function, functions, parameterised, functional, neural, pytrees, transformations

JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. That this seems like a fundamental difference means current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations, multiple new abstractions, and been limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way this OO/functional difference has been a source of tension. Here, we introduce 'Equinox', a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One: parameterised functions are themselves represented as 'PyTrees', which means that the parameterisation of a function is transparent to the JAX framework. Two: we filter a PyTree to isolate just those components that should be treated when transforming ('jit', 'grad' or 'vmap'-ing) a higher-order function of a parameterised function, such as a loss function applied to a model. Overall Equinox resolves the above tension without introducing any new programmatic abstractions: only PyTrees and transformations, just as with regular JAX. Equinox is available at https://github.com/patrick-kidger/equinox.
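
The two ideas in the abstract can be illustrated with a minimal sketch written in plain JAX. This is not Equinox's implementation or API: the class `Linear`, the predicate `is_inexact_array`, and the helper `filter_grad` below are hypothetical names written for illustration, and the sketch assumes a recent JAX release in which `jax.Array` is available. The class is registered as a PyTree, so JAX can see its parameters directly, and the helper differentiates only those leaves that pass the filter, leaving the activation function (a non-array leaf) untouched.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class Linear:
    """Idea one: a callable class that is itself a PyTree (hypothetical example)."""

    def __init__(self, weight, bias, activation):
        self.weight = weight
        self.bias = bias
        self.activation = activation  # an ordinary Python callable, stored as a leaf

    def __call__(self, x):
        return self.activation(self.weight @ x + self.bias)

    def tree_flatten(self):
        # All three fields are children; there is no auxiliary data.
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def is_inexact_array(leaf):
    """Filter: True only for floating-point or complex JAX arrays."""
    return isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.inexact)


def filter_grad(fn, predicate):
    """Idea two: differentiate fn with respect to only those leaves of its first
    argument that pass `predicate`; every other leaf is held fixed."""

    def wrapped(model, *args):
        leaves, treedef = tree_flatten(model)
        mask = [predicate(leaf) for leaf in leaves]
        selected = [leaf for leaf, keep in zip(leaves, mask) if keep]

        def inner(selected_leaves):
            it = iter(selected_leaves)
            merged = [next(it) if keep else leaf for leaf, keep in zip(leaves, mask)]
            return fn(tree_unflatten(treedef, merged), *args)

        grads = iter(jax.grad(inner)(selected))
        # The gradient PyTree mirrors the model; filtered-out leaves become None.
        return tree_unflatten(treedef, [next(grads) if keep else None for keep in mask])

    return wrapped


def loss(model, x, y):
    # A higher-order function of a parameterised function: it takes the model itself.
    return jnp.mean((model(x) - y) ** 2)


key = jax.random.PRNGKey(0)
model = Linear(jax.random.normal(key, (2, 3)), jnp.zeros(2), jnp.tanh)
x, y = jnp.ones(3), jnp.zeros(2)

grads = filter_grad(loss, is_inexact_array)(model, x, y)
```

Calling `jax.grad(loss)(model, x, y)` directly would fail here, because the activation function is a leaf of the PyTree but not a differentiable array. The filter is what makes it possible to pass the whole model object across the transformation boundary. The result `grads` is again a `Linear` instance, holding gradients in `weight` and `bias` and `None` in `activation`.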



Forward citations

Cited by 11 Pith papers

Reviewed papers in the Pith corpus that reference this work. Sorted by Pith novelty score.

  1. Provable Data Scaling Law for Meta Learning via Complexity Minimization

    stat.ML 2026-06 unverdicted novelty 7.0

    A novel complexity minimization meta-learning framework provably demonstrates that few-shot adaptation error decreases as meta-training data volume increases.

  2. Observer-robust energy condition verification for warp drive spacetimes

    gr-qc 2026-02 accept novelty 7.0

    Warpax toolkit demonstrates that observer-robust optimization finds more extensive and severe energy-condition violations in warp drive metrics than single-frame Eulerian analysis.

  3. AMIGO: a Data-Driven Calibration of the JWST Interferometer

    astro-ph.IM 2025-10 unverdicted novelty 7.0

    AMIGO is an end-to-end differentiable forward model of JWST AMI that corrects detector systematics to recover high-precision astrometry and detect close high-contrast companions.

  4. Learning partially observed systems with neural Hamiltonian ordinary differential equations

    cs.LG 2026-05 unverdicted novelty 6.0

    NHODE framework learns partially observed dynamical systems by combining Hamiltonian neural networks with neural ODEs, enforcing energy conservation and improving long-horizon stability over data-driven baselines on m...

  5. Convex Optimization for Alignment and Preference Learning on a Single GPU

    cs.LG 2026-05 unverdicted novelty 6.0

    COALA applies convex optimization reformulations of neural networks to direct preference optimization, claiming single-GPU training with ~18% of DPO's TFLOPs and competitive performance on multiple datasets and models...

  6. Closed-form predictive coding via hierarchical Gaussian filters

    cs.LG 2026-05 unverdicted novelty 6.0

    Predictive coding is recast as deep hierarchical Gaussian filters to restore precision-weighted message passing, yielding closed-form inference and online precision learning that matches backpropagation speed on Fashi...

  7. A Unifying Framework for Parallelizing Sequential Models with Linear Dynamical Systems

    cs.LG 2025-09 unverdicted novelty 6.0

    A framework based on linear dynamical systems unifies fixed-point iteration schemes such as Newton, Picard, and Jacobi as approximate linearizations of nonlinear recursions for parallelizing sequential models.

  8. On the boundary cost of source-consistent warp shells

    gr-qc 2026-05 unverdicted novelty 5.0

    Source-consistent warp shells fail energy conditions at the source-vacuum boundary in all examined constructions and parameter scans.

  9. GCImOpt: Learning efficient goal-conditioned policies by imitating optimal trajectories

    cs.RO 2026-04 unverdicted novelty 5.0

    GCImOpt trains compact goal-conditioned neural policies by imitating efficiently generated optimal trajectories, achieving high success rates and near-optimal performance on cart-pole, quadcopter, and robot arm tasks ...

  10. Uncertainty in Physics and AI: Taxonomy, Quantification, and Validation

    stat.ML 2026-05 accept novelty 4.0

    A unified taxonomy of uncertainty in ML for physics is introduced together with validation tools such as coverage, calibration, and proper scoring rules, illustrated on regression and classification tasks.

  11. jNO: A JAX Library for Neural Operator and Foundation Model Training

    cs.LG 2026-05 unverdicted novelty 4.0

    jNO introduces a unified JAX tracing system for data-driven and physics-informed neural operator training that compiles domains, residuals, losses, and diagnostics into one pipeline.