pith.

arxiv: 2606.22932 · v1 · pith:GWT7BTE6 · submitted 2026-06-22 · 💻 cs.LG

FORGE: Fused On-Register Gradient Elimination for Memory-Efficient LLM Training

Pith reviewed 2026-06-26 09:08 UTC · model grok-4.3

classification 💻 cs.LG
keywords memory-efficient training · gradient fusion · optimizer integration · LLM training · register-level computation · tensor parallelism · mixed-precision optimization

The pith

Fusing the optimizer step into the backward pass eliminates the need to store full gradient tensors.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper argues that writing every gradient to memory before the optimizer reads it is an artifact of separating differentiation from the update step, not a requirement of learning itself. It folds the optimizer directly into the backward pass so each gradient tile is applied the moment it is computed and never leaves the registers. In full precision this produces the identical update for any element-wise optimizer rule, and the identity holds under tensor and sequence sharding. The approach therefore removes the memory peak at the gradient-optimizer boundary while preserving the computed weights exactly in full precision and faithfully in lower precision.

Core claim

Reverse-mode differentiation computes every weight gradient, writes it to memory, and only then lets the optimizer read it back. FORGE folds the optimizer step into the backward pass and applies it one tile at a time, entirely in registers, so each gradient tile is consumed the instant it is produced and never becomes a tensor. In full precision the fused step is provably exact (the identical optimizer update, for every element-wise rule), and that exactness survives tensor- and sequence-parallel sharding. In the bf16 and 8-bit regimes the deviation is bounded and, for the weight store, rendered unbiased by stochastic rounding. Because each gradient tile is born and consumed in the same registers, it is never converted down to bf16 to be stored and read back.
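The claim that stochastic rounding makes the bf16 weight store unbiased can be illustrated with a minimal sketch. It assumes the common bit-level formulation (add uniform random bits below the bf16 cut, then truncate); the function names are hypothetical and the paper's kernels may implement rounding differently. Rounding up happens with probability equal to the discarded fraction, so the average of many roundings converges to the fp32 input, whereas plain truncation is biased toward zero.

```python
import random
import struct


def fp32_bits(x: float) -> int:
    """IEEE-754 single-precision bit pattern of x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_fp32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a single-precision float."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def bf16_round_down(x: float) -> float:
    """Deterministic truncation to bf16: keep the top 16 bits (biased toward zero)."""
    return bits_fp32(fp32_bits(x) & 0xFFFF0000)


def bf16_stochastic_round(x: float, rng: random.Random) -> float:
    """Stochastic rounding to bf16 for finite, non-overflowing inputs.

    Adding uniform noise in [0, 2**16) to the 16 discarded bits and truncating
    rounds up with probability (discarded bits) / 2**16, so E[result] == x.
    """
    b = (fp32_bits(x) + rng.getrandbits(16)) & 0xFFFF0000
    return bits_fp32(b)


rng = random.Random(0)
x = 1.0 + 2**-10  # exact in fp32, not representable in bf16 (spacing 2**-7 near 1.0)
samples = [bf16_stochastic_round(x, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)
```

Each sample is either 1.0 or 1.0078125 (the two bf16 neighbours of x), and round-up occurs about one time in eight, so `mean` lands close to x while `bf16_round_down(x)` is always 1.0.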

What carries the argument

On-register fusion of gradient computation with the immediate element-wise optimizer update, which consumes each tile without materializing the full gradient tensor.
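A pure-Python sketch of that mechanism for one linear layer follows, assuming the layout Y = X @ W (so dW = Xᵀ dY) and SGD with momentum as the element-wise rule. A small list stands in for register-resident accumulators, and all names, sizes, and hyperparameters are hypothetical; the paper's kernels are GPU code (Triton), not this. The sketch only covers the weight gradient: in a real backward pass the input gradient dX = dY Wᵀ must still see the old W, so it has to be formed before W is overwritten.

```python
import copy
import random
from typing import Callable, List, Tuple

Matrix = List[List[float]]
# Element-wise rule: (weight, gradient, momentum) -> (new weight, new momentum)
Rule = Callable[[float, float, float], Tuple[float, float]]


def sgd_momentum(w: float, g: float, m: float) -> Tuple[float, float]:
    """Element-wise SGD with momentum (lr=0.1, mu=0.9, chosen for illustration)."""
    m = 0.9 * m + g
    return w - 0.1 * m, m


def wgrad_elem(X: Matrix, dY: Matrix, i: int, j: int) -> float:
    """dW[i][j] = sum_t X[t][i] * dY[t][j], accumulated in a fixed token order."""
    acc = 0.0
    for t in range(len(X)):
        acc += X[t][i] * dY[t][j]
    return acc


def two_phase_step(W: Matrix, M: Matrix, X: Matrix, dY: Matrix, rule: Rule) -> None:
    rows, cols = len(W), len(W[0])
    # Phase A: the full gradient tensor is materialized...
    G = [[wgrad_elem(X, dY, i, j) for j in range(cols)] for i in range(rows)]
    # Phase B: ...and only then read back by the optimizer.
    for i in range(rows):
        for j in range(cols):
            W[i][j], M[i][j] = rule(W[i][j], G[i][j], M[i][j])


def fused_step(W: Matrix, M: Matrix, X: Matrix, dY: Matrix, rule: Rule, tile: int = 2) -> None:
    rows, cols = len(W), len(W[0])
    for i0 in range(0, rows, tile):
        for j0 in range(0, cols, tile):
            ii = range(i0, min(i0 + tile, rows))
            jj = range(j0, min(j0 + tile, cols))
            # Gradient for this output tile only; wgrad never reads W, so
            # updating earlier tiles cannot change this tile's gradient.
            g_tile = [[wgrad_elem(X, dY, i, j) for j in jj] for i in ii]
            # Consume the tile immediately; it is replaced on the next iteration.
            for a, i in enumerate(ii):
                for b, j in enumerate(jj):
                    W[i][j], M[i][j] = rule(W[i][j], g_tile[a][b], M[i][j])


rng = random.Random(0)
X = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]   # 5 tokens, d_in = 3
dY = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]  # upstream grad, d_out = 4
W0 = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(3)]
W_ref, M_ref = copy.deepcopy(W0), [[0.0] * 4 for _ in range(3)]
W_fused, M_fused = copy.deepcopy(W0), [[0.0] * 4 for _ in range(3)]
two_phase_step(W_ref, M_ref, X, dY, sgd_momentum)
fused_step(W_fused, M_fused, X, dY, sgd_momentum, tile=2)
```

With a 3×4 weight and 2×2 tiles, the edge tiles are partial, yet `W_fused` and `M_fused` match the two-phase result exactly, because every element sees the same gradient value and the same arithmetic.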

If this is right

  • Memory required for the optimizer step is more than halved.
  • At the small batch sizes typical of fine-tuning, training runs about 1.5 times faster.
  • Integrated into tensor-parallel Megatron-LM, the same GPUs can fit 8B training at four times the micro-batch a standard optimizer allows.
  • Full-precision fidelity is preserved because gradients never undergo the down-conversion required for storage in other low-precision schemes.
  • The method applies to any linear layer under any element-wise optimizer rule.

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • The tile-wise consumption pattern could be extended to other element-wise operations that occur after the backward pass if they admit register-level fusion.
  • Lower memory pressure on gradients might allow activation checkpointing to be relaxed in some training regimes without increasing peak memory.
  • The avoidance of extra quantization steps for stored gradients could reduce accumulated rounding error in long training runs.
  • Because the fusion survives sequence parallelism, it may directly benefit long-context training where activation memory is already high.

Load-bearing premise

Every optimizer rule used in practice must be strictly element-wise so that each tile can be updated independently without changing the final result.
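Why this premise carries so much weight can be shown with a small counter-example. The sketch is a simplified stand-in, not the paper's code: it compares an element-wise rule (plain SGD) with a LAMB-style per-tensor trust ratio, reduced here to ‖w‖/‖g‖ on the raw gradient (no moments, no weight decay) to keep it short. Applied tile by tile, the element-wise rule reproduces the whole-tensor result exactly, while the trust-ratio rule does not, because each tile computes its own norms. This is the kind of coupling through a row, column, block, or global statistic that, per the paper's optimizer table, forces a fallback to a standard optimizer step.

```python
import math
from typing import Callable, List

Vec = List[float]
Rule = Callable[[Vec, Vec], Vec]


def norm(v: Vec) -> float:
    return math.sqrt(sum(x * x for x in v))


def sgd(w: Vec, g: Vec, lr: float = 0.1) -> Vec:
    """Element-wise: output i depends only on w[i] and g[i]."""
    return [wi - lr * gi for wi, gi in zip(w, g)]


def trust_ratio_step(w: Vec, g: Vec, lr: float = 0.1) -> Vec:
    """Simplified LAMB-style step: the scale ||w|| / ||g|| is computed over
    whatever slice of the tensor this call sees."""
    r = norm(w) / norm(g)
    return [wi - lr * r * gi for wi, gi in zip(w, g)]


def apply_tiled(rule: Rule, w: Vec, g: Vec, tile: int) -> Vec:
    """Apply `rule` to consecutive tiles independently, as a fused kernel would."""
    out: Vec = []
    for s in range(0, len(w), tile):
        out.extend(rule(w[s:s + tile], g[s:s + tile]))
    return out


w = [1.0, 2.0, 3.0, 4.0]
g = [0.5, -0.5, 0.25, 1.0]
sgd_whole, sgd_tiled = sgd(w, g), apply_tiled(sgd, w, g, tile=2)
tr_whole, tr_tiled = trust_ratio_step(w, g), apply_tiled(trust_ratio_step, w, g, tile=2)
```

Here `sgd_whole` equals `sgd_tiled` exactly, while `tr_whole` and `tr_tiled` already differ at the first element: the whole-tensor ratio is √30/1.25 ≈ 4.38, but the first tile uses √10 ≈ 3.16.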

What would settle it

A side-by-side execution of one optimizer step on a small model in full precision, comparing the exact weight values produced by the standard two-phase method versus the fused method.
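On a toy problem, that side-by-side check can be sketched as follows. Python floats are IEEE doubles and stand in for "full precision". The quadratic loss, tile size, step count, and AdamW hyperparameters are illustrative assumptions, and the names are hypothetical. On a GPU, bitwise agreement would additionally require both paths to accumulate each gradient element in the same order, which this sketch guarantees by construction.

```python
import math
import random
from typing import List


class AdamWState:
    """Weights plus per-element AdamW moments, stored as plain lists."""

    def __init__(self, w: List[float]):
        self.w = list(w)
        self.m = [0.0] * len(w)
        self.v = [0.0] * len(w)


def adamw_elem(st: AdamWState, i: int, g: float, t: int,
               lr: float = 1e-2, b1: float = 0.9, b2: float = 0.999,
               eps: float = 1e-8, wd: float = 0.01) -> None:
    """Update element i with AdamW (bias-corrected moments, decoupled weight decay)."""
    st.m[i] = b1 * st.m[i] + (1 - b1) * g
    st.v[i] = b2 * st.v[i] + (1 - b2) * g * g
    m_hat = st.m[i] / (1 - b1 ** t)
    v_hat = st.v[i] / (1 - b2 ** t)
    st.w[i] -= lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * st.w[i])


def grad_elem(w: List[float], target: List[float], i: int) -> float:
    """Gradient at index i of the toy loss 0.5 * sum_i (w[i] - target[i])**2."""
    return w[i] - target[i]


def two_phase(st: AdamWState, target: List[float], t: int) -> None:
    grads = [grad_elem(st.w, target, i) for i in range(len(st.w))]  # materialized
    for i, g in enumerate(grads):
        adamw_elem(st, i, g, t)


def fused(st: AdamWState, target: List[float], t: int, tile: int = 3) -> None:
    n = len(st.w)
    for s in range(0, n, tile):
        idx = range(s, min(s + tile, n))
        g_tile = [grad_elem(st.w, target, i) for i in idx]  # tile-local only
        for i, g in zip(idx, g_tile):
            adamw_elem(st, i, g, t)


rng = random.Random(0)
w0 = [rng.gauss(0.0, 1.0) for _ in range(10)]
target = [rng.gauss(0.0, 1.0) for _ in range(10)]
ref, fus = AdamWState(w0), AdamWState(w0)
for t in range(1, 6):  # step counter starts at 1 for bias correction
    two_phase(ref, target, t)
    fused(fus, target, t, tile=3)
```

After five steps the weights and both moment buffers agree element for element (`ref.w == fus.w`, and likewise for `m` and `v`), which is the outcome the exactness claim predicts.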

Figures

Figures reproduced from arXiv: 2606.22932 by Aik Beng Ng, Avinash Anand, Bapi Chatterjee, Dikshant Kukreja, Erik Cambria, Kritarth Prasad, Simon See, Timothy Liu, Zhengkui Wang.

Figure 1. FORGE (8-bit state) on Llama-3.1-8B (H200, BS=1, SEQ=512, BF16-everywhere) cuts peak …
Figure 2. Measured within-step GPU memory (memory_allocated) over one forward+backward+optimizer cycle (Llama-3.1-8B). Fused AdamW (red) holds its peak through backward and the optimizer step, where the gradient and fp32 moments are co-resident (∼64 GB); FORGE (blue) sheds memory through backward, since each gradient tile is consumed in place, and never spikes. With activation checkpointing, FORGE +AC (green) stays flat at …
Figure 3. FORGE mechanism: per-tile data flow inside the fused backward-and-optimizer step. The weight …
Figure 4. Convergence parity on continued pretraining over OpenMathInstruct-2 (BS …
Figure 5. Headline single-GPU comparison on Llama-3.1-8B (H200 141 GB, BS=1, SEQ=512, BF16-…
Figure 6. From-scratch pretraining parity on GPT-2 124M (nanoGPT) over FineWeb-Edu (sample-10BT, …
Figure 7. Composability with FP8 and activation checkpointing at BS …
Figure 8. Multi-model scaling on H200 (BS=1, SEQ=512, BF16-everywhere): peak memory of fused AdamW vs. FORGE across the Qwen3 family and Llama-3.1-8B, with FORGE’s step-time speedup over fused AdamW labeled above each model. FORGE holds peak memory to 0.47–0.73× the fused-AdamW footprint and runs 1.07–1.52× faster; at Qwen3-14B fused AdamW OOMs at this configuration while FORGE runs at 150.1 ms / 63.2 GiB. optimizer…
Figure 9. Distributed sweeps on 8×A100, FORGE vs. fused-CUDA-AdamW, across model size, tensor-parallel degree, micro-batch, and sequence length. Top: per-rank peak-memory saving (20–55%), largest where optimizer state dominates and at small batch/sequence. Bottom: step-time ratio vs. fused AdamW (the dashed line marks parity), within 0.8–1.2× throughout: parity at no speed cost. Every reclaimed gigabyte is micro-batch…
Figure 10. Llama-style 8B on four A100-40GB (TP=4, SEQ=1024): per-rank peak memory by micro-batch. foreach AdamW OOMs at every micro-batch and fused-CUDA-AdamW fits only micro-batch 1 (34.0 GiB / 305.1 ms); FORGE stays under the 40 GiB card all the way to micro-batch 4 (24.6 GiB), four times the micro-batch on the same hardware, at per-step latencies 241.3/346.7/504.0 ms for micro-batch 1/2/4. At micro-batch 4 FORGE …
Figure 11. Convergence parity across the remaining Qwen3 sizes on OpenMathInstruct-2 continued pretraining …
Figure 12. Per-family training-loss curve (left sub-panel), peak GPU memory (center sub-panel), and average …
Figure 13. Optimizer sweep (continued from Figure …
Figure 14. Optimizer sweep (continued from Figure …
Original abstract

Reverse-mode differentiation computes every weight gradient, writes it to memory, and only then lets the optimizer read it back. This two-phase schedule sets the memory ceiling of modern training: at the seam between the phases, every layer's gradient is live at once. We argue that this materialized gradient is an artifact of how differentiation is staged, not a quantity that learning requires -- and we eliminate it. FORGE folds the optimizer step into the backward pass and applies it one tile at a time, entirely in registers, so each gradient tile is consumed the instant it is produced and never becomes a tensor. The fusion changes only when the update happens, not what it computes: in full precision the fused step is provably exact -- the identical optimizer update, for every element-wise rule -- and that exactness survives tensor- and sequence-parallel sharding; in the bf16 and 8-bit regimes used in practice it is faithful rather than bit-identical, its deviation bounded and, for the weight store, rendered unbiased by stochastic rounding. Because each gradient tile is born and consumed in the same registers, it is never converted down to bf16 to be stored and read back; FORGE thus preserves the full-precision fidelity that both bf16 and 8-bit optimizers lose to that conversion. Nor is the method tied to one architecture or one optimizer: linear layers are ubiquitous, and FORGE reclaims the gradient memory of any of them under any element-wise rule. Empirically FORGE more than halves the memory of an optimizer step and, at the small batch sizes typical of fine-tuning and continued pretraining, runs about 1.5x faster; integrated into tensor-parallel Megatron-LM it fits 8B training at four times the micro-batch a standard optimizer allows on the same GPUs.
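To put the seam in perspective, the size of a materialized gradient is simple arithmetic: one element per parameter. The back-of-envelope sketch below assumes a nominal 8×10⁹ parameters and a hypothetical 128×128 fp32 tile (the paper's tile shape is not stated here). It ignores weights, optimizer state, and activations, so it does not attempt to reproduce the reported end-to-end savings.

```python
def gib(nbytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes)."""
    return nbytes / 2**30


def materialized_grad_bytes(n_params: int, bytes_per_elem: int) -> int:
    """A materialized gradient stores one element per parameter."""
    return n_params * bytes_per_elem


def tile_bytes(rows: int, cols: int, bytes_per_elem: int = 4) -> int:
    """One gradient tile held as fp32 accumulators."""
    return rows * cols * bytes_per_elem


N_PARAMS = 8_000_000_000  # nominal "8B" model, illustrative only
bf16_grad = materialized_grad_bytes(N_PARAMS, 2)
fp32_grad = materialized_grad_bytes(N_PARAMS, 4)
one_tile = tile_bytes(128, 128)  # hypothetical 128 x 128 tile

print(f"bf16 gradient: {gib(bf16_grad):.1f} GiB")  # 14.9 GiB
print(f"fp32 gradient: {gib(fp32_grad):.1f} GiB")  # 29.8 GiB
print(f"one fp32 tile: {one_tile // 1024} KiB")     # 64 KiB
```

Under the two-phase schedule the whole gradient is live at the seam; under tile-wise fusion, each tile in flight needs only tens of kilobytes.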

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it. The pith above is the substance; this is the friction.

Referee Report

2 major / 2 minor

Summary. The paper introduces FORGE, a technique that fuses element-wise optimizer updates into the backward pass of reverse-mode autodiff, processing gradients tile-by-tile entirely in registers so that no full gradient tensor is ever materialized. It claims that in full precision the result is mathematically identical to the standard two-phase schedule for any element-wise optimizer rule, that this equivalence is preserved under tensor and sequence parallelism, and that in reduced-precision regimes the deviation is bounded and unbiased for the weights via stochastic rounding. Empirically it reports more than 2x reduction in optimizer memory and ~1.5x speedup at small batch sizes, enabling 4x larger micro-batches for 8B models inside Megatron-LM.

Significance. If the exactness claim holds, the work directly attacks the gradient-materialization bottleneck that dominates memory usage in LLM training. The ability to reclaim gradient memory for any linear layer under arbitrary element-wise rules, while preserving full-precision fidelity that bf16/8-bit optimizers normally lose, would be a practical advance for fine-tuning and continued pretraining. The sharding invariance is a particularly useful property for distributed training.

major comments (2)
  1. [Abstract / correctness argument] Abstract and the central correctness argument: the manuscript asserts that the fused on-register update is provably identical to the standard optimizer step for every element-wise rule, yet the explicit derivation (including the handling of tile boundaries, accumulation order, and why no cross-tile dependencies arise) is not supplied. This equivalence is the load-bearing claim for both the full-precision guarantee and the sharding invariance.
  2. [Abstract / scope of element-wise rules] The weakest assumption noted in the stress test—that every practical optimizer rule is strictly element-wise and that tiling introduces neither numerical nor synchronization side-effects—receives no counter-example analysis or formal statement of the class of supported rules (e.g., whether Adam’s second-moment state or any non-element-wise preconditioner would break the argument).
minor comments (2)
  1. [Empirical results] Reported speedups and memory numbers lack error bars, baseline implementation details (exact Megatron version, optimizer hyperparameters), and hardware specifications.
  2. [Method description] Notation for “tile” and “register” residency is used without a precise definition or pseudocode that would allow reproduction of the fusion schedule.

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for the positive summary and for identifying the two points that require strengthening. Both comments correctly note that the abstract states strong claims without the supporting formal material; we will revise the manuscript to supply the requested derivation and formal scope statement.

Point-by-point responses
  1. Referee: [Abstract / correctness argument] Abstract and the central correctness argument: the manuscript asserts that the fused on-register update is provably identical to the standard optimizer step for every element-wise rule, yet the explicit derivation (including the handling of tile boundaries, accumulation order, and why no cross-tile dependencies arise) is not supplied. This equivalence is the load-bearing claim for both the full-precision guarantee and the sharding invariance.

    Authors: We agree that the explicit derivation is absent from the current manuscript. Section 3 sketches the argument at a high level but does not provide the tile-by-tile expansion, accumulation-order invariance, or proof that element-wise independence precludes cross-tile data dependencies. We will add a dedicated subsection (or appendix) containing the full derivation: because each optimizer step is a per-element function of the corresponding gradient entry, weight entry, and optimizer state, the result of applying the update to a tile is independent of all other tiles; hence register-level processing yields a mathematically identical weight update regardless of tiling or sharding. The same independence immediately implies invariance under tensor and sequence parallelism. Revision will be made. revision: yes

  2. Referee: [Abstract / scope of element-wise rules] The weakest assumption noted in the stress test—that every practical optimizer rule is strictly element-wise and that tiling introduces neither numerical nor synchronization side-effects—receives no counter-example analysis or formal statement of the class of supported rules (e.g., whether Adam’s second-moment state or any non-element-wise preconditioner would break the argument).

    Authors: We will add both the requested formal statement and counter-example analysis. The supported class is defined as any optimizer whose update for each parameter is a strictly element-wise function of that parameter’s gradient, current value, and per-parameter optimizer state (explicitly including Adam’s first- and second-moment buffers, which remain element-wise). We will state that non-element-wise preconditioners (e.g., those requiring matrix inverses or cross-parameter statistics) fall outside the class and cannot be fused without materializing the full gradient; a short counter-example paragraph will illustrate the breakage. Because all operations remain local to a tile, no additional synchronization is introduced. Revision will be made. revision: yes
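The derivation promised in the first response and the class of rules defined in the second can be written compactly. The sketch uses its own notation (φ for the per-element rule, θ for its hyperparameters, T_k for tiles or shards) and assumes that each gradient entry g_i reaching φ is the same fully accumulated value in both schedules.

```latex
\begin{aligned}
&\text{Supported class: } (w_i^{+},\, s_i^{+}) = \phi\bigl(w_i,\, g_i,\, s_i;\ t,\, \theta\bigr)
  \quad \text{for each index } i \in \mathcal{I},\\
&\qquad \text{e.g. AdamW: } s_i = (m_i, v_i),\ \theta = (\eta, \beta_1, \beta_2, \epsilon, \lambda).\\[4pt]
&\text{Tiling or sharding: } \mathcal{I} = \textstyle\bigsqcup_{k} T_k \quad (\text{disjoint index sets}).\\[4pt]
&\text{Two-phase: store } g = (g_i)_{i \in \mathcal{I}}, \text{ then apply } \phi \text{ at every } i \in \mathcal{I}.\\
&\text{Fused: for each } k, \text{ form } (g_i)_{i \in T_k} \text{ and apply } \phi \text{ at every } i \in T_k.\\[4pt]
&\phi \text{ at } i \text{ reads only } (w_i, g_i, s_i) \text{ and writes only } (w_i, s_i)
  \;\Longrightarrow\; (w^{+}, s^{+})\big|_{T_k}^{\text{fused}} = (w^{+}, s^{+})\big|_{T_k}^{\text{two-phase}} \ \ \forall k\\
&\;\Longrightarrow\; (w^{+}, s^{+})^{\text{fused}} = (w^{+}, s^{+})^{\text{two-phase}}
  \quad \text{since the } T_k \text{ cover } \mathcal{I}.\\[4pt]
&\text{Outside the class: LAMB's per-tensor trust ratio uses norms over all of } \mathcal{I},
  \text{ so it is not a function of } (w_i, g_i, s_i) \text{ alone.}
\end{aligned}
```

Because the argument only uses disjointness of the T_k, the same step covers a tensor-parallel split of the weight into per-rank shards.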

Circularity Check

0 steps flagged

No significant circularity

full rationale

The paper's central claim asserts that the fused on-register update produces the identical optimizer step (for any element-wise rule) in full precision, with the property preserved under sharding, because only the timing of the update is altered. This follows directly from the definition of element-wise arithmetic and the absence of cross-tile dependencies in the stated scope; no equations, fitted parameters, or self-citation chains are invoked to derive the exactness result. The argument is scoped to implementation mechanics of gradient materialization and does not reduce the claimed identity to any input quantity by construction.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axioms · 0 invented entities

The approach rests on the domain assumption that optimizer rules are element-wise and that register-level tiling preserves the mathematical result; no free parameters or new entities are introduced.

axioms (1)
  • domain assumption Optimizer update rules are strictly element-wise operations that commute with tiling.
    Invoked to guarantee that per-tile application yields the identical global update.

pith-pipeline@v0.9.1-grok · 5892 in / 1161 out tokens · 17820 ms · 2026-06-26T09:08:46.237562+00:00 · methodology


Reference graph

Works this paper leans on

14 extracted references · 1 canonical work pages

  1. [1]

    update-then-sum

    doi: 10.48550/arxiv.2406.16793. Jiawei Zhao, Zhenyu Zhang, Beidi Chen, Zhangyang Wang, Anima Anandkumar, and Yuandong Tian. GaLore: Memory-efficient LLM training by gradient low-rank projection. In International Conference on Machine Learning, 2024. doi: 10.48550/arxiv.2403.03507. Yanli Zhao, Andrew Gu, Rohan Varma, Liang Luo, Chien-Chin Huang, Min Xu, Les...

  2. [2]

    AdamW (Loshchilov & Hutter, 2019): first and second moments, decoupled weight decay

  3. [3]

    NAdam: AdamW with Nesterov momentum correction on m̂

  4. [4]

    RAdam: AdamW with rectified-variance term

  5. [5]

    Lion: sign-of-momentum update with a single moment buffer

  6. [6]

    RMSprop: second moment only, no first moment

  7. [7]

    AdaGrad: cumulative squared gradient (no decay)

  8. [8]

    SGD: vanilla gradient descent (no moment buffers)

  9. [9]

    SGD with momentum: single moment, no second moment. (ii) Linear-layer fusion only (standard optimizer step). These five couple coordinates through a row, column, block, or global statistic, so the update cannot be formed inside one tile; FORGE runs the tiled linear-layer backward and applies the optimizer as a standard step. They are included to show the l...

  10. [10]

    Adam-mini (Zhang et al., 2025): AdamW with shared per-block second moment

  11. [11]

    Adafactor (Shazeer & Stern, 2018): factored (row×column) second-moment statistics

  12. [12]

    AdaLomo (Lv et al., 2024): low-memory AdamW with factored second moment

  13. [13]

    LAMB: layer-wise adaptive scaling (per-tensor trust ratio)

  14. [14]

    SM3: per-dimension running max for the cumulative moment. Each fully-fused family replaces the body of Phase 2 of Algorithm 1 with ∼10 lines of Triton arithmetic; the wgrad mainloop (Phase 1) and the state read/write pattern (Phase 3) are unchanged. Cross-element preconditioners (Muon’s Newton–Schulz orthogonalization (Jordan et al., 2024), Shampoo, SOA...
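The structure described in this entry, where each fully-fused family swaps only the Phase-2 arithmetic while Phase 1 (the wgrad mainloop) and Phase 3 (state read/write) stay fixed, can be sketched in pure Python. The sketch assumes scalar per-element bodies for several of the element-wise families listed here (AdamW, NAdam, and RAdam are omitted for brevity). Algorithm 1 itself is not reproduced: Phase 1's output is stubbed by precomputed gradients, and all names and hyperparameter values are hypothetical. It is not the paper's Triton code.

```python
import math
from typing import Callable, Dict, List, Tuple

State = Dict[str, float]
Hyper = Dict[str, float]
# A Phase-2 body maps (w, g, state, hyperparameters) -> (new w, new state) for ONE element.
Body = Callable[[float, float, State, Hyper], Tuple[float, State]]


def sgd(w: float, g: float, s: State, h: Hyper) -> Tuple[float, State]:
    return w - h["lr"] * g, s  # no moment buffers


def sgd_momentum(w: float, g: float, s: State, h: Hyper) -> Tuple[float, State]:
    m = h["mu"] * s["m"] + g  # single moment
    return w - h["lr"] * m, {"m": m}


def rmsprop(w: float, g: float, s: State, h: Hyper) -> Tuple[float, State]:
    v = h["alpha"] * s["v"] + (1 - h["alpha"]) * g * g  # second moment only
    return w - h["lr"] * g / (math.sqrt(v) + h["eps"]), {"v": v}


def adagrad(w: float, g: float, s: State, h: Hyper) -> Tuple[float, State]:
    v = s["v"] + g * g  # cumulative squared gradient, no decay
    return w - h["lr"] * g / (math.sqrt(v) + h["eps"]), {"v": v}


def lion(w: float, g: float, s: State, h: Hyper) -> Tuple[float, State]:
    c = h["b1"] * s["m"] + (1 - h["b1"]) * g
    sign = float((c > 0) - (c < 0))  # sign of the interpolated momentum
    w = w - h["lr"] * (sign + h["wd"] * w)
    return w, {"m": h["b2"] * s["m"] + (1 - h["b2"]) * g}  # single moment buffer


# family name -> (Phase-2 body, initial per-element state, hyperparameters)
FAMILIES: Dict[str, Tuple[Body, State, Hyper]] = {
    "sgd": (sgd, {}, {"lr": 0.1}),
    "sgd_momentum": (sgd_momentum, {"m": 0.0}, {"lr": 0.1, "mu": 0.9}),
    "rmsprop": (rmsprop, {"v": 0.0}, {"lr": 0.01, "alpha": 0.99, "eps": 1e-8}),
    "adagrad": (adagrad, {"v": 0.0}, {"lr": 0.1, "eps": 1e-10}),
    "lion": (lion, {"m": 0.0}, {"lr": 1e-3, "b1": 0.9, "b2": 0.99, "wd": 0.0}),
}


def run(name: str, w: List[float], grads: List[List[float]], tile: int):
    """Shared skeleton: only the Phase-2 body changes between families."""
    body, init, h = FAMILIES[name]
    W, S = list(w), [dict(init) for _ in w]
    for g in grads:                        # one gradient vector per step
        for s0 in range(0, len(W), tile):  # Phase 1 would emit this tile
            for i in range(s0, min(s0 + tile, len(W))):
                # Phase 3 reads state, Phase 2 updates, Phase 3 writes back.
                W[i], S[i] = body(W[i], g[i], S[i], h)
    return W, S


w = [0.5, -1.0, 2.0, 0.25, -0.75]
grads = [[0.1, -0.2, 0.3, 0.0, -0.4], [0.2, 0.1, -0.1, 0.05, 0.3]]
whole = {name: run(name, w, grads, tile=len(w)) for name in FAMILIES}
tiled = {name: run(name, w, grads, tile=2) for name in FAMILIES}
```

For every family, the tiled run reproduces the untiled one exactly, which is what lets the skeleton stay fixed while only the few lines of per-element arithmetic change.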