arxiv: 2605.19269 · v2 · pith:AIQS2Q4Inew · submitted 2026-05-19 · 💻 cs.LG

CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs

Pith reviewed 2026-05-21 07:17 UTC · model grok-4.3

classification 💻 cs.LG
keywords Transformer · GEMM · Epilogue fusion · GPU kernels · Memory optimization · Kernel abstraction · Deep learning training

The pith

Transformer operators can be reparameterized to execute as epilogues while GEMM output tiles remain on chip.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper argues that many of the operators surrounding the matrix multiplications in a Transformer block can be algebraically rewritten to run as small programs attached to a GEMM instead of as separate memory-bound kernels. By keeping the GEMM mainloop fixed and exposing only a narrow set of epilogue operations for scaling, reductions, pairwise math, and accumulation, the approach avoids writing large intermediate tensors back to global memory. This matters because, once the core linear algebra is highly optimized, these operators still account for a nontrivial fraction of end-to-end time. The abstraction is claimed to be expressive enough for nearly all non-attention work in both the forward and backward passes while preserving the performance structure of expert-written GEMM kernels.

Core claim

CODA fixes the GEMM mainloop and exposes a small set of composable epilogue primitives so that normalization, activations, residual updates, and related computations execute while a GEMM output tile is still resident on chip, before any write to global memory. This reparameterization covers nearly all non-attention computation in the forward and backward pass of a standard Transformer block and preserves the performance structure of expert-written GEMMs.
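
The core claim can be illustrated with a small pure-Python sketch. This is not CODA's API: `tiled_gemm`, `make_residual_silu`, and the tile sizes are hypothetical names and values chosen for illustration. A local variable stands in for an on-chip output tile. The sketch only shows that an element-wise epilogue (SiLU plus a residual add) applied tile by tile gives the same result as running the operators as separate passes over the full GEMM output.

```python
import math

def matmul(a, b):
    """Reference dense matmul on lists of lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def silu(x):
    return x / (1.0 + math.exp(-x))

def tiled_gemm(a, b, tile_m, tile_n, epilogue, out):
    """Fixed mainloop: compute one output tile, hand it to the epilogue
    while it is still a local value, then store it exactly once."""
    m, k, n = len(a), len(b), len(b[0])
    for i0 in range(0, m, tile_m):
        for j0 in range(0, n, tile_n):
            rows = range(i0, min(i0 + tile_m, m))
            cols = range(j0, min(j0 + tile_n, n))
            # Mainloop: accumulate the tile (hypothetical stand-in for MMA).
            acc = [[sum(a[i][p] * b[p][j] for p in range(k)) for j in cols]
                   for i in rows]
            # Epilogue: transform the tile before the single store.
            tile = epilogue(acc, rows, cols)
            for ti, i in enumerate(rows):
                for tj, j in enumerate(cols):
                    out[i][j] = tile[ti][tj]

def make_residual_silu(residual):
    """Hypothetical epilogue: activation followed by a residual update."""
    def epilogue(acc, rows, cols):
        return [[silu(acc[ti][tj]) + residual[i][j]
                 for tj, j in enumerate(cols)]
                for ti, i in enumerate(rows)]
    return epilogue

a = [[0.1 * (i + j) for j in range(6)] for i in range(5)]
b = [[0.05 * (i - j) for j in range(7)] for i in range(6)]
res = [[1.0 for _ in range(7)] for _ in range(5)]

fused = [[0.0] * 7 for _ in range(5)]
tiled_gemm(a, b, tile_m=2, tile_n=3,
           epilogue=make_residual_silu(res), out=fused)

# Unfused path: full GEMM output is materialized, then post-processed.
h = matmul(a, b)
unfused = [[silu(h[i][j]) + res[i][j] for j in range(7)] for i in range(5)]
max_diff = max(abs(fused[i][j] - unfused[i][j])
               for i in range(5) for j in range(7))
```

In the unfused path the intermediate `h` would be a full tensor in global memory; in the fused path it exists only tile by tile. Operators that need information from other tiles, such as row-wise norms, do not fit this simple pattern and need the reduction treatment described elsewhere in this page.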

What carries the argument

The GEMM-plus-epilogue programming model that fixes the mainloop and provides composable primitives for scaling, reductions, pairwise transformations, and accumulation.
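
As an illustration of what a pairwise transformation primitive might compute, the sketch below applies SwiGLU, its backward pass, and a RoPE-style rotation to adjacent feature pairs of one output row. The interleaved (gate, value) layout, the function names, and the angle values are assumptions made for this example and are not taken from the paper's implementation.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    return x * sigmoid(x)

def swiglu_pairs(row):
    """Two inputs -> one output: (gate, value) -> silu(gate) * value."""
    return [silu(row[j]) * row[j + 1] for j in range(0, len(row), 2)]

def swiglu_pairs_backward(row, dout):
    """One incoming gradient -> gradients for both paired inputs."""
    grad = []
    for p, j in enumerate(range(0, len(row), 2)):
        g, v = row[j], row[j + 1]
        s = sigmoid(g)
        dsilu = s * (1.0 + g * (1.0 - s))  # derivative of silu at g
        grad += [dout[p] * v * dsilu, dout[p] * silu(g)]
    return grad

def rope_pairs(row, angles):
    """Two inputs -> two outputs: rotate each pair by its own angle."""
    out = []
    for p, j in enumerate(range(0, len(row), 2)):
        c, s = math.cos(angles[p]), math.sin(angles[p])
        out += [row[j] * c - row[j + 1] * s, row[j] * s + row[j + 1] * c]
    return out

row = [0.5, -1.0, 2.0, 0.25]          # assumed layout: g0, v0, g1, v1
y = swiglu_pairs(row)                 # half as many features as the input
dx = swiglu_pairs_backward(row, [1.0, 1.0])
rot = rope_pairs(row, [0.3, 1.1])     # same number of features as the input
```

Each function touches only two neighboring values at a time, which is why this kind of operator can plausibly run on a tile without cross-tile communication, provided both members of a pair land in the same tile.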

If this is right

  • Nearly all non-attention work in forward and backward passes of a standard Transformer block fits inside the epilogue interface.
  • Both human-written and LLM-written CODA kernels reach high performance on representative workloads.
  • GEMM-plus-epilogue programming provides a practical route to framework productivity while retaining hardware efficiency.

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • The same on-chip epilogue style might be applied to attention blocks if the mainloop can be extended without losing GEMM efficiency.
  • Automated code generators could target the constrained epilogue interface to produce kernels for new model variants without manual tuning.

Load-bearing premise

The small set of epilogue primitives is expressive enough to cover nearly all non-attention computation without forcing the GEMM mainloop to be rewritten.

What would settle it

Measure end-to-end training throughput of a full Transformer block implemented entirely with CODA kernels versus the same block using separate framework kernels for normalization, activations, and residuals.

Figures

Figures reproduced from arXiv: 2605.19269 by Arjun Menon, Driss Guessous, Han Guo, Jack Zhang, Tri Dao, Vijay Thakkar, Yoon Kim.

Figure 1: Runtime breakdown for LLaMA-3-style 1B model training on a single H100 using TorchTitan. LLM training has become just as much of a systems problem as a modeling one. FLOPs in modern Transformer-based LLMs are dominated by matrix multiplications and attention, whose kernels have been heavily optimized for Tensor Core execution. Yet Transformers, and deep learning architectures more broadly, also contain …
Figure 2: Forward pass of a standard Transformer layer. The top row shows the canonical formulation, …
Figure 3: A GEMM mainloop computes output tiles; the epilogue transforms each tile before the final global-memory store. The epilogue is a natural place to implement fusions because the output of the matmul is already present on chip close to compute cores. Practical epilogues commonly perform scaling, bias addition, activations, residual updates, data type conversions, tile-wise reductions and other output elem…
Figure 4: GEMM-RMSNorm-GEMM reparameterization. We address the reduction by splitting it into two levels. The first GEMM epilogue computes tile-local partial reductions, and a small auxiliary kernel reduces these partials across tiles to obtain r. Since the auxiliary kernel reads a few partial values per tile rather than the full activation tensor, its memory traffic is much smaller than that of a standalone RMSNo…
Figure 5: Benchmarks. Here r = 1/√(reduce(rb) + ϵ) is computed by a small auxiliary reduction over the tile partials. This decomposition replaces a standalone RMSNorm kernel with tile-local epilogue work around the two GEMMs, plus a lightweight auxiliary reduction. In …
Figure 6: Relative error. Numerics. The reparameterization changes where the RMSNorm scale is applied: the row-wise factor r is delayed from before the second GEMM to the second GEMM epilogue. We compare BF16 GEMM-RMSNorm-GEMM outputs against an FP32 reference on Llama-3 8B layers. We report the errors of CODA and QuACK, on which our GEMM template is based, normalized by the error of the standard PyTorch path …
Figure 7: Pairwise activations operate on local feature pairs in the GEMM epilogue. This form captures several operations in Transformer blocks: • RoPE rotates each feature pair and return two outputs; • SwiGLU combines gate and value stream into one output; • SwiGLU backward pass maps one incoming gradient into gradients for both paired inputs. Pairwise activations couple neighboring feature lanes and may change …
Figure 8: Kernel-level speedups for representative GEMM-plus-epilogue primitives across …
Figure 9: Forward and backward fusion for GEMM–epilogue blocks. Forward epilogues attach to the …
Figure 10: Kernel-level speedups on reparameterized Transformer kernels relative to cuBLAS with …
Figure 11: Block-level speedups for reparameterized Transformer kernel sequences, including …
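
The GEMM-RMSNorm-GEMM reparameterization described in the captions for Figures 4 to 6 rests on the fact that the RMSNorm factor r is a per-row scalar, so it commutes with the second matrix multiplication: (r ⊙ h ⊙ g) W2 = r ⊙ ((h ⊙ g) W2). The following pure-Python sketch checks that identity together with the two-level reduction. Function names, shapes, the tile width, and the choice of sum-of-squares partials are assumptions for illustration; it runs in double precision and says nothing about BF16 behavior.

```python
import math

def matmul(a, b):
    """Plain dense matmul; stands in for the fixed GEMM mainloop."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def gemm1_epilogue(x, w1, g, tile_n):
    """GEMM 1 plus epilogue: tile-local partial sums of squares of the
    unscaled output, and a side output scaled by the norm weight g."""
    h = matmul(x, w1)
    n = len(h[0])
    partials = [[sum(v * v for v in row[j0:j0 + tile_n])
                 for j0 in range(0, n, tile_n)] for row in h]
    scaled = [[v * g[j] for j, v in enumerate(row)] for row in h]
    return scaled, partials

def auxiliary_reduce(partials, n, eps):
    """Second reduction level: combine per-tile partials into r per row."""
    return [1.0 / math.sqrt(sum(p) / n + eps) for p in partials]

def gemm2_epilogue(hg, w2, r):
    """GEMM 2 plus epilogue: the delayed row-wise scale r is applied here."""
    y = matmul(hg, w2)
    return [[r[i] * v for v in row] for i, row in enumerate(y)]

def reference(x, w1, g, w2, eps):
    """Canonical order: GEMM, standalone RMSNorm, GEMM."""
    h = matmul(x, w1)
    n = len(h[0])
    normed = []
    for row in h:
        rstd = 1.0 / math.sqrt(sum(v * v for v in row) / n + eps)
        normed.append([rstd * v * g[j] for j, v in enumerate(row)])
    return matmul(normed, w2)

M, K, N, P, EPS = 3, 4, 8, 5, 1e-6     # assumed toy shapes
x = [[math.sin(i + 2 * j) for j in range(K)] for i in range(M)]
w1 = [[math.cos(3 * i + j) for j in range(N)] for i in range(K)]
g = [1.0 + 0.1 * j for j in range(N)]
w2 = [[math.sin(i * j + 1) for j in range(P)] for i in range(N)]

hg, partials = gemm1_epilogue(x, w1, g, tile_n=4)
r = auxiliary_reduce(partials, N, EPS)
fused = gemm2_epilogue(hg, w2, r)
ref = reference(x, w1, g, w2, EPS)
max_err = max(abs(fused[i][j] - ref[i][j])
              for i in range(M) for j in range(P))
```

Only `partials` (one value per row per column tile) crosses between the two GEMMs in addition to the scaled activation, which is the source of the reduced memory traffic that the Figure 4 caption refers to.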
Original abstract

Transformer training systems are built around dense linear algebra, yet a nontrivial fraction of end-to-end time is spent on surrounding memory-bound operators. Normalization, activations, residual updates, reductions, and related computations repeatedly move large intermediate tensors through global memory while performing little arithmetic, making data movement an increasingly important bottleneck in otherwise highly optimized training stacks. We introduce CODA, a GPU kernel abstraction that expresses these computations as GEMM-plus-epilogue programs. CODA is based on the observation that many Transformer operators exposed as separate framework kernels can be algebraically reparameterized to execute while a GEMM output tile remains on chip, before it is written to memory. The abstraction fixes the GEMM mainloop and exposes a small set of composable epilogue primitives for scaling, reductions, pairwise transformations, and accumulation. This constrained interface preserves the performance structure of expert-written GEMMs while remaining expressive enough to cover nearly all non-attention computation in the forward and backward pass of a standard Transformer block. Across representative Transformer workloads, both human- and LLM-authored CODA kernels achieve high performance, suggesting that GEMM-plus-epilogue programming offers a practical path toward combining framework-level productivity with hardware-level efficiency.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

1 major / 2 minor

Summary. The paper introduces CODA, a GPU kernel abstraction that rewrites many non-attention Transformer operators (normalization, activations, residuals, reductions) as GEMM-plus-epilogue programs. The core observation is that these operators can be algebraically reparameterized to execute on GEMM output tiles while they remain on-chip, before writing to global memory. The abstraction fixes the GEMM mainloop and exposes a constrained set of composable epilogue primitives (scaling, reductions, pairwise transformations, accumulation) that is claimed to cover nearly all non-attention computation in standard Transformer forward and backward passes while preserving expert GEMM performance. Both human- and LLM-authored CODA kernels are reported to achieve high performance across representative workloads.

Significance. If the algebraic reparameterization and epilogue expressiveness hold, the work offers a practical route to fusing memory-bound operators with high-performance GEMMs, reducing data movement in Transformer training stacks. The parameter-free algebraic approach and the constrained yet composable epilogue interface are strengths that could improve both productivity and efficiency over ad-hoc kernel fusion. The manuscript's emphasis on LLM-authored kernels also highlights a potential path toward automated kernel generation.

major comments (1)
  1. [§3.2 and §4.2] §3.2 (Epilogue Primitives) and §4.2 (Norm Fusion): The reduction primitive is presented as operating on the GEMM output tile to enable on-chip LayerNorm/RMSNorm. However, standard GEMM tiling produces output tiles whose N-dimension (typically 128) is much smaller than the hidden dimension (e.g., 4096). The text does not specify how partial row reductions for mean/variance are accumulated across tiles using only registers or shared memory without intermediate global writes; if cross-tile communication requires extra memory traffic, the claimed elimination of separate memory-bound kernels for per-token norms does not hold.
minor comments (2)
  1. [Abstract and §5] The abstract and results sections state that CODA kernels achieve 'high performance' but provide no quantitative metrics, baselines, or error bars; adding a table with speedups or roofline comparisons would strengthen the performance claims.
  2. [§3.2] Notation for the epilogue primitive signatures (e.g., how reduction scope is parameterized) could be clarified with a small example in §3.2 to make the interface more accessible.

Simulated Author's Rebuttal

1 response · 0 unresolved

We thank the referee for the careful review and constructive feedback. The observation about cross-tile reduction mechanics for norms is a valid point on presentation clarity, and we address it directly below.

Point-by-point responses
  1. Referee: [§3.2 and §4.2] §3.2 (Epilogue Primitives) and §4.2 (Norm Fusion): The reduction primitive is presented as operating on the GEMM output tile to enable on-chip LayerNorm/RMSNorm. However, standard GEMM tiling produces output tiles whose N-dimension (typically 128) is much smaller than the hidden dimension (e.g., 4096). The text does not specify how partial row reductions for mean/variance are accumulated across tiles using only registers or shared memory without intermediate global writes; if cross-tile communication requires extra memory traffic, the claimed elimination of separate memory-bound kernels for per-token norms does not hold.

    Authors: We agree that the manuscript does not provide an explicit description of cross-tile accumulation. In the CODA epilogue design, each output tile performs intra-tile reductions for partial sums and sums-of-squares using registers and shared memory. Across tiles spanning the full hidden dimension of a row, partial statistics are aggregated through a compact per-row auxiliary buffer in global memory via atomic additions executed within the same kernel launch. This keeps the dominant activation tensors on-chip during the GEMM mainloop and epilogue while avoiding separate full-tensor kernel launches. We acknowledge that this incurs limited additional global traffic proportional to tile count rather than tensor size; the net reduction in memory movement relative to unfused baselines remains substantial. We will revise §§3.2 and 4.2 to include a precise description, pseudocode, and a diagram of the multi-tile reduction flow.

    Revision: yes
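
The "proportional to tile count rather than tensor size" point can be made concrete with a rough back-of-envelope model. The hidden dimension (4096) and tile width (128) come from the referee comment; the row count, the BF16 activation size, the FP32 statistic size, and the one-statistic-per-tile assumption are illustrative choices. The model ignores caches, atomics, and the GEMM's own output store, so it is an estimate of relative scale only, not a measurement.

```python
def unfused_norm_bytes(m, n, act_bytes):
    """A standalone norm kernel reads the activation once and writes it once."""
    return 2 * m * n * act_bytes

def partial_stats_bytes(m, n, tile_n, stat_bytes, stats_per_tile=1):
    """Per-row, per-tile partial statistics: written once by the epilogue
    and read once by the cross-tile reduction."""
    tiles_per_row = -(-n // tile_n)  # ceiling division
    return 2 * m * tiles_per_row * stats_per_tile * stat_bytes

m, n, tile_n = 8192, 4096, 128       # m is an assumed token count
full = unfused_norm_bytes(m, n, act_bytes=2)             # BF16 activations
aux = partial_stats_bytes(m, n, tile_n, stat_bytes=4)    # FP32 partials
ratio = full / aux

print(full, aux, ratio)  # 134217728 2097152 64.0
```

Under these assumptions the auxiliary traffic is 64 times smaller than a standalone norm pass; with two statistics per tile (for example sum and sum of squares for LayerNorm) the ratio would halve.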

Circularity Check

0 steps flagged

No circularity: algebraic reparameterization is self-contained observation

full rationale

The paper's derivation rests on an algebraic observation that non-attention Transformer operators can be reparameterized to run as GEMM epilogues while the output tile stays on-chip. This is presented directly as an observation in the abstract and does not reduce to a self-definitional loop, a fitted parameter renamed as a prediction, or any load-bearing self-citation. The claim that the constrained epilogue primitives (scaling, reductions, pairwise transformations, accumulation) cover nearly all such operators is an independent expressiveness assertion rather than a tautology or imported uniqueness result. No equations or steps in the provided text equate the output to the input by construction. The approach is therefore self-contained against external benchmarks of kernel fusion and data-movement reduction.

Axiom & Free-Parameter Ledger

0 free parameters · 2 axioms · 1 invented entity

The paper introduces a new constrained programming interface whose correctness depends on standard assumptions about operator semantics and GPU memory behavior rather than new fitted constants or invented physical entities.

axioms (2)
  • [domain assumption] Algebraic reparameterization of normalization, activation, residual, and reduction operators preserves semantics when fused into a GEMM epilogue.
    Invoked when stating that operators can execute while the GEMM output tile remains on chip.
  • [domain assumption] The performance structure of expert-written GEMM mainloops is preserved when only the epilogue is modified.
    Central to the claim that the abstraction maintains high performance.
invented entities (1)
  • CODA epilogue primitives [no independent evidence]
    Purpose: composable operations for scaling, reductions, pairwise transformations, and accumulation.
    New set of building blocks introduced to express non-attention computations inside the GEMM epilogue.

pith-pipeline@v0.9.0 · 5755 in / 1547 out tokens · 69767 ms · 2026-05-21T07:17:17.395279+00:00 · methodology
