pith.

arxiv: 2402.14212 · v4 · pith:RDJZWGC2 · submitted 2024-02-22 · 💻 cs.LG · cs.AI

Moonwalk: Inverse-Forward Differentiation

Pith reviewed 2026-05-25 08:38 UTC · model grok-4.3

classification 💻 cs.LG cs.AI
keywords inverse-forward differentiation · submersive networks · fragmental gradient checkpointing · vector-inverse-Jacobian product · memory-efficient training · deep network depth · mixed-mode differentiation

The pith

A differentiation method reconstructs gradients in a forward sweep to train neural networks more than twice as deep under the same memory limit as backpropagation.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

Standard backpropagation stores all intermediate activations during the forward pass, which caps how deep a network can be trained under a fixed memory budget. The paper defines submersive networks, whose layer Jacobians have trivial cokernels, so that gradients can be reconstructed exactly during a forward sweep. For layers that are not submersive, the method records only the minimal residuals needed, via fragmental checkpointing. A mixed-mode algorithm first runs a memory-efficient reverse pass for input gradients, then uses a new vector-inverse-Jacobian product operator to recover parameter gradients in a forward sweep, matching backpropagation's runtime while removing the activation-storage requirement.
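The submersive condition can be checked numerically on toy layers. The sketch below assumes each Jacobian is small enough to materialize as a dense matrix of shape (outputs × inputs) and computes the cokernel dimension as the number of outputs minus the rank; `cokernel_dim` and `is_submersive` are hypothetical helper names, not part of Moonwalk.

```python
import numpy as np


def cokernel_dim(J):
    """Dimension of coker(J) = R^m / im(J) for a Jacobian J of shape (m outputs, n inputs)."""
    return J.shape[0] - np.linalg.matrix_rank(J)


def is_submersive(J):
    """A layer is submersive at this point when its Jacobian is surjective (trivial cokernel)."""
    return cokernel_dim(J) == 0


rng = np.random.default_rng(0)
z = np.array([1.0, -2.0, 0.5, -0.1])  # pre-activations; two units are inactive under ReLU
jacobians = {
    "dense 5->3": rng.standard_normal((3, 5)),           # narrowing dense layer
    "dense 3->5": rng.standard_normal((5, 3)),           # widening dense layer
    "relu": np.diag((z > 0).astype(float)),              # zero derivative at inactive units
    "leaky relu": np.diag(np.where(z > 0, 1.0, 0.1)),    # derivative never vanishes
}
for name, J in jacobians.items():
    print(f"{name:>10}: cokernel dim {cokernel_dim(J)}, submersive {is_submersive(J)}")
```

Under this convention a dense layer that narrows or preserves width is generically submersive, while a widening layer or a ReLU with inactive units is not; a leaky ReLU stays submersive because its derivative never vanishes.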

Core claim

The paper shows that, for submersive networks, gradients can be reconstructed exactly without stored activations by using the vector-inverse-Jacobian product to invert gradient flow outside each Jacobian's cokernel. For non-submersive layers, fragmental gradient checkpointing restores the erased cotangents by storing only the minimal necessary subset of residuals. The resulting mixed-mode procedure therefore eliminates activation storage while preserving exact gradients.
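A toy, end-to-end sketch of the mixed-mode idea, under stated assumptions: the network is a chain of dense layers whose widths never increase, each followed by a leaky ReLU, so every layer is submersive and no checkpointing is needed. For brevity the input gradient g0 comes from an ordinary activation-storing backprop pass, which stands in for the paper's memory-efficient reverse pass and does not reproduce it. All function names are hypothetical; the point is only that a forward sweep holding one activation and one cotangent at a time recovers the same parameter gradients as backpropagation.

```python
import numpy as np

ALPHA = 0.1  # leaky-ReLU slope; a nonzero derivative keeps the activation submersive


def leaky(z):
    return np.where(z > 0, z, ALPHA * z)


def leaky_deriv(z):
    return np.where(z > 0, 1.0, ALPHA)


def make_net(sizes, rng):
    """Dense layers with non-increasing widths, so each W (out x in) has full row rank."""
    return [(rng.standard_normal((m, n)) / np.sqrt(n), rng.standard_normal(m))
            for n, m in zip(sizes[:-1], sizes[1:])]


def backprop(params, x0):
    """Reference: store every activation, then chain VJPs in reverse.

    Loss is L = 0.5 * ||x_L||^2, so dL/dx_L = x_L.
    """
    xs = [x0]                                   # [x0, z1, a1, z2, a2, ...]
    for W, b in params:
        z = W @ xs[-1] + b
        xs += [z, leaky(z)]
    g = xs[-1]
    grads = []
    for i in reversed(range(len(params))):
        x_in, z = xs[2 * i], xs[2 * i + 1]
        g = g * leaky_deriv(z)                  # VJP through leaky ReLU: dL/dz
        grads.append((np.outer(g, x_in), g))    # (dL/dW, dL/db)
        g = params[i][0].T @ g                  # VJP through dense layer: dL/dx_in
    return g, grads[::-1]


def moonwalk_forward(params, x0, g0):
    """Forward sweep: recompute activations and rebuild cotangents with vijps.

    Only the current activation x and its cotangent g are held in memory.
    """
    x, g, grads = x0, g0, []
    for W, b in params:
        # vijp through the dense layer: solve W^T g_z = g, unique since W has full row rank
        g_z = np.linalg.solve(W @ W.T, W @ g)
        grads.append((np.outer(g_z, x), g_z))
        z = W @ x + b
        # vijp through leaky ReLU: undo the elementwise product g_z = g_a * leaky'(z)
        g = g_z / leaky_deriv(z)
        x = leaky(z)
    return grads, g, x


rng = np.random.default_rng(0)
params = make_net([8, 8, 6, 4], rng)
x0 = rng.standard_normal(8)
g0, ref = backprop(params, x0)      # toy stand-in for the paper's memory-efficient reverse pass
grads, g_last, x_last = moonwalk_forward(params, x0, g0)
err = max(np.abs(a - b).max() for pair in zip(grads, ref) for a, b in zip(*pair))
print(f"max |moonwalk - backprop| = {err:.2e}")
print("final cotangent equals dL/dx_L:", np.allclose(g_last, x_last))
```

The dense-layer vijp here solves a small linear system per layer for clarity; how the paper's implementation keeps this step competitive with a VJP is not shown by this sketch.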

What carries the argument

The vector-inverse-Jacobian product (vijp), which inverts gradient flow outside the cokernel of each layer's Jacobian to enable exact forward reconstruction of parameter gradients.
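For a layer whose Jacobian J is available as a dense matrix, one way to realize a vijp is to solve J^T g_y = g_x for the output cotangent g_y. The sketch below assumes that representation and uses `numpy.linalg.lstsq`; `vijp` is a hypothetical name for illustration, not the paper's implementation.

```python
import numpy as np


def vjp(J, g_y):
    """Vector-Jacobian product: pull an output cotangent back to the input, g_x = J^T g_y."""
    return J.T @ g_y


def vijp(J, g_x):
    """Hypothetical dense vijp: recover g_y from g_x = J^T g_y.

    Least squares returns the minimum-norm solution, which is the exact g_y
    when J has full row rank (trivial cokernel). Otherwise the component of
    g_y lying in the cokernel, ker(J^T), is unrecoverable and comes back as zero.
    """
    g_y, *_ = np.linalg.lstsq(J.T, g_x, rcond=None)
    return g_y


rng = np.random.default_rng(1)
J_sub = rng.standard_normal((3, 5))   # 5 -> 3 map: full row rank, trivial cokernel
J_wide = rng.standard_normal((5, 3))  # 3 -> 5 map: 2-dimensional cokernel
g_y3, g_y5 = rng.standard_normal(3), rng.standard_normal(5)
print("submersive round trip exact:", np.allclose(vijp(J_sub, vjp(J_sub, g_y3)), g_y3))
print("widening round trip exact:  ", np.allclose(vijp(J_wide, vjp(J_wide, g_y5)), g_y5))
```

For the surjective 5 → 3 Jacobian the round trip is exact; for the widening 3 → 5 Jacobian the least-squares solution drops exactly the component of g_y that lies in ker(J^T), which is the information the abstract describes as lost in the cokernel.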

If this is right

  • Networks can be trained more than twice as deep under the same memory budget.
  • Runtime remains comparable to standard backpropagation.
  • Exact gradients are obtained without storing all activations for submersive layers.
  • Only a minimal subset of residuals needs to be recorded for non-submersive layers.
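One plausible reading of fragmental checkpointing for a ReLU layer, assuming the stored residual is the part of the output cotangent at inactive units (the entries the Jacobian zeroes out); the paper's actual residual layout may differ, and the function names are hypothetical. The pre-activation z is not stored: the forward sweep recomputes it.

```python
import numpy as np


def relu_reverse(z, g_a):
    """Reverse step through a = relu(z): return g_z and the erased fragment of g_a."""
    active = z > 0
    g_z = np.where(active, g_a, 0.0)   # VJP: cotangents of inactive units are zeroed
    fragment = g_a[~active]            # the only part of g_a that g_z cannot determine
    return g_z, fragment


def relu_forward_restore(z, g_z, fragment):
    """Forward step: rebuild g_a from g_z plus the stored fragment (z is recomputed)."""
    active = z > 0
    g_a = np.empty_like(g_z)
    g_a[active] = g_z[active]          # invert the identity part of the Jacobian
    g_a[~active] = fragment            # restore the cotangents erased by the Jacobian
    return g_a


z = np.array([1.5, -0.3, 0.2, -2.0, 0.7])
g_a = np.array([0.4, -1.1, 2.0, 0.9, -0.5])
g_z, frag = relu_reverse(z, g_a)
restored = relu_forward_restore(z, g_z, frag)
print(frag.size, np.array_equal(restored, g_a))  # prints: 2 True
```

The stored fragment has one entry per inactive unit, so its size equals the cokernel dimension of that ReLU's Jacobian at this input.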

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • The same forward-reconstruction idea could be tested on recurrent or stateful architectures where activation storage is especially costly.
  • If the vijp operator generalizes cheaply to new layer types, memory savings could be applied to training on edge devices or with very large batch sizes.
  • The cokernel view might suggest redesigning layers themselves to become submersive and thereby remove checkpointing entirely.

Load-bearing premise

The vector-inverse-Jacobian product can be computed at a cost comparable to a standard vector-Jacobian product and the cokernel analysis holds for the layers used in practice.
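A rough illustration of the cost question, not the paper's complexity statement. For an elementwise layer such as leaky ReLU the Jacobian is diagonal, so the vijp is an elementwise division with the same cost as the VJP's multiplication. For a dense layer with weight W (m × n, full row rank), the vijp applies pinv(W^T) = (W W^T)^{-1} W; assuming W is fixed for a whole training step, that matrix can be formed once and reused for every sample, after which each vijp is one matrix-vector product like the VJP. All names are hypothetical.

```python
import numpy as np

ALPHA = 0.1  # assumed leaky-ReLU slope


# Elementwise leaky ReLU: diagonal Jacobian, so vjp multiplies and vijp divides.
def leaky_vjp(z, g_a):
    return g_a * np.where(z > 0, 1.0, ALPHA)


def leaky_vijp(z, g_z):
    return g_z / np.where(z > 0, 1.0, ALPHA)


# Dense y = W x with W of shape (m, n), m <= n, full row rank.
def dense_vjp(W, G_y):
    """Batched VJP, rows are samples: (batch, m) -> (batch, n), i.e. g_x = W^T g_y."""
    return G_y @ W


def make_dense_vijp(W):
    """Form pinv(W^T) once per step (shape (m, n)), then reuse it for the whole batch."""
    P = np.linalg.pinv(W.T)
    return lambda G_x: G_x @ P.T       # (batch, n) -> (batch, m), i.e. g_y = P g_x


rng = np.random.default_rng(2)
z, g = rng.standard_normal(6), rng.standard_normal(6)
W = rng.standard_normal((4, 7))
G_y = rng.standard_normal((32, 4))
print(np.allclose(leaky_vijp(z, leaky_vjp(z, g)), g))
print(np.allclose(make_dense_vijp(W)(dense_vjp(W, G_y)), G_y))
```

The one-time pseudo-inverse, and the conditioning of W W^T (an ill-conditioned W amplifies rounding error in the reconstructed cotangent), are where this sketch's cost and accuracy could diverge from a plain VJP.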

What would settle it

Run the method on a network more than twice as deep as the backpropagation baseline under identical memory and measure whether wall-clock time matches backpropagation while final gradients and training loss match to machine precision.

Figures

Figures reproduced from arXiv: 2402.14212 by Armin Karamzade, Dmitrii Krylov, Roy Fox.

Figure 1: The computation flow diagram of Moonwalk: (a) …
Figure 2: Maximum allocated memory during training. The …
Figure 3: Maximum allocated memory during training for a …
Figure 5: Train accuracy of three models trained with …
Figure 6: L2 gradient error between true gradients computed …
Original abstract

Backpropagation's main limitation is its need to store intermediate activations (residuals) during the forward pass, which restricts the depth of trainable networks. This raises a fundamental question: can we avoid storing these activations? We address this by revisiting the structure of gradient computation. Backpropagation computes gradients through a sequence of vector-Jacobian products, an operation that is generally irreversible. The lost information lies in the cokernel of each layer's Jacobian. We define submersive networks -- networks whose layer Jacobians have trivial cokernels -- in which gradients can be reconstructed exactly in a forward sweep without storing activations. For non-submersive layers, we introduce fragmental gradient checkpointing, which records only the minimal subset of residuals necessary to restore the cotangents erased by the Jacobian. Central to our approach is a novel operator, the vector-inverse-Jacobian product (vijp), which inverts gradient flow outside the cokernel. Our mixed-mode algorithm first computes input gradients with a memory-efficient reverse pass, then reconstructs parameter gradients in a forward sweep using the vijp, eliminating the need to store activations. We implement this method in Moonwalk and show that it matches backpropagation's runtime while training networks more than twice as deep under the same memory budget.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, and this is the friction.

Referee Report

2 major / 1 minor

Summary. The paper introduces Moonwalk, a mixed-mode inverse-forward differentiation algorithm that computes input gradients via a memory-efficient reverse pass and reconstructs parameter gradients in a forward sweep using a new vector-inverse-Jacobian product (vijp) operator. It defines submersive networks (those with layer Jacobians having trivial cokernels) for exact activation-free gradient reconstruction and introduces fragmental gradient checkpointing for non-submersive layers, claiming to match backpropagation's runtime while training networks more than twice as deep under identical memory budgets.

Significance. If the runtime equivalence and memory scaling claims hold, the work would enable substantially deeper networks without proportional increases in activation storage, addressing a core scalability bottleneck in deep learning. The vijp operator and submersive-network framing constitute a distinct algorithmic contribution beyond standard checkpointing or reversible architectures.

major comments (2)
  1. Abstract: the headline claim that Moonwalk 'matches backpropagation's runtime' while supporting '>2x deeper' networks under fixed memory requires that vijp evaluation cost is bounded by a small constant factor over VJP and that cokernel analysis applies to layers used in practice; neither a complexity statement nor layer-by-layer verification (linear, conv, ReLU) is supplied to support this load-bearing premise.
  2. Abstract: the assertion that fragmental gradient checkpointing 'records only the minimal subset of residuals' for non-submersive layers is not accompanied by any quantification of residual size or empirical memory scaling; if common layers possess non-trivial cokernels, the stored residuals could prevent the claimed doubling of depth at constant memory.
minor comments (1)
  1. The abstract would be strengthened by naming the specific architectures and depths used in the Moonwalk experiments that underpin the '>2x deeper' claim.

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for their constructive feedback. We address the two major comments below and will revise the manuscript accordingly to provide the requested supporting analysis and quantification.

  1. Referee: Abstract: the headline claim that Moonwalk 'matches backpropagation's runtime' while supporting '>2x deeper' networks under fixed memory requires that vijp evaluation cost is bounded by a small constant factor over VJP and that cokernel analysis applies to layers used in practice; neither a complexity statement nor layer-by-layer verification (linear, conv, ReLU) is supplied to support this load-bearing premise.

    Authors: We agree that the abstract's claims would benefit from explicit supporting material. In the revision we will add a formal complexity statement for the vijp operator (showing its cost is bounded by a small constant factor over VJP) to Section 3 and include a new subsection with layer-by-layer cokernel verification for linear, convolutional, and ReLU layers. The abstract will be updated to reference this analysis. revision: yes

  2. Referee: Abstract: the assertion that fragmental gradient checkpointing 'records only the minimal subset of residuals' for non-submersive layers is not accompanied by any quantification of residual size or empirical memory scaling; if common layers possess non-trivial cokernels, the stored residuals could prevent the claimed doubling of depth at constant memory.

    Authors: The referee correctly notes the absence of quantification. We will revise the manuscript to add explicit quantification of residual sizes (tied to cokernel dimension) for non-submersive layers and include new empirical memory-scaling results in the experiments section demonstrating that the stored residuals remain small enough to support the >2x depth claim. The abstract will be adjusted to reflect these additions. revision: yes

Circularity Check

0 steps flagged

No circularity: algorithmic construction is independent of its inputs

full rationale

The paper defines submersive networks via the standard linear-algebra notion of trivial cokernel of the Jacobian, introduces vijp as a new operator that inverts gradient flow outside that cokernel, and presents fragmental checkpointing and the mixed-mode algorithm as explicit algorithmic choices. None of these steps reduce by definition or by self-citation to the performance claims; the runtime and memory assertions are empirical outcomes of the implementation rather than tautological consequences of the definitions. No equations equate a fitted quantity to a prediction, no uniqueness theorem is imported from prior self-work, and no ansatz is smuggled via citation. The derivation chain therefore remains self-contained against external benchmarks.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axiom · 3 invented entities

The approach introduces three new conceptual objects whose independent empirical support is not supplied in the abstract.

axioms (1)
  • domain assumption: Layer Jacobians admit a well-defined cokernel whose dimension determines information loss during gradient flow.
    Invoked to define submersive networks and to justify exact forward reconstruction.
invented entities (3)
  • submersive network · no independent evidence
    purpose: Network whose layer Jacobians have trivial cokernels so gradients can be recovered forward without stored activations.
    New definition introduced to enable the memory-saving claim.
  • vector-inverse-Jacobian product (vijp) · no independent evidence
    purpose: Operator that inverts gradient flow outside the cokernel.
    Novel operator required for the forward reconstruction step.
  • fragmental gradient checkpointing · no independent evidence
    purpose: Stores only the minimal residuals needed to restore erased cotangents for non-submersive layers.
    New checkpointing variant required for general networks.

pith-pipeline@v0.9.0 · 5756 in / 1434 out tokens · 29300 ms · 2026-05-25T08:38:55.474051+00:00

