
arxiv: 2605.19269 · v1 · pith:AIQS2Q4I · submitted 2026-05-19 · 💻 cs.LG

CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs

Pith reviewed 2026-05-20 07:15 UTC · model grok-4.3

classification 💻 cs.LG
keywords transformer, GEMM, epilogue, GPU kernel, memory-bound operators, kernel fusion, training efficiency

The pith

Non-attention computations in Transformer blocks can run as epilogues to GEMM operations while the output tile stays on chip.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper establishes that many separate memory-bound operators in Transformers, including normalization, activations, residuals, and reductions, can be algebraically rewritten to execute as part of a GEMM computation's epilogue. This keeps the result in fast on-chip memory instead of moving large tensors through global memory. The abstraction keeps the expert-tuned GEMM main loop fixed and adds a small set of composable epilogue primitives for scaling, reductions, pairwise operations, and accumulation. These primitives prove expressive enough to cover nearly all non-attention work in both the forward and backward passes of a standard Transformer block. The approach yields high-performance kernels whether written by humans or generated by LLMs on representative workloads.

Core claim

CODA lets non-attention Transformer operators execute as GEMM-plus-epilogue programs by algebraically reparameterizing them so they run while the GEMM output tile remains on chip before any global write. The interface exposes a constrained but composable set of epilogue primitives that together handle scaling, reductions, pairwise transformations, and accumulation, preserving the performance characteristics of hand-written GEMM kernels.

What carries the argument

The GEMM-plus-epilogue program abstraction that fixes the GEMM mainloop and supplies composable epilogue primitives.

If this is right

  • Separate kernel launches and global-memory round trips for normalization, activations, and residuals become unnecessary.
  • End-to-end training time decreases because data movement shrinks while arithmetic intensity stays high.
  • Both manually authored and automatically generated kernels can match the speed of expert GEMM implementations.
  • The same interface works for forward and backward passes without changing the underlying GEMM schedule.

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the authors make directly.

  • If the epilogue interface proves general, frameworks could automatically rewrite entire blocks this way rather than relying on manual fusion.
  • The pattern may extend to attention layers once suitable reparameterizations for softmax or attention scores are identified.
  • Hardware vendors could add dedicated on-chip support for these specific epilogue primitives to further reduce latency.

Load-bearing premise

The small set of epilogue primitives remains expressive enough to represent nearly all non-attention operators through algebraic reparameterization without forcing the GEMM output tile off chip or losing performance.

What would settle it

An operator from the non-attention portion of a Transformer block whose algebraic form cannot be expressed with the given epilogue primitives or whose execution forces an early global write of the GEMM tile, producing measurable slowdown versus the unfused baseline.

Figures

Figures reproduced from arXiv: 2605.19269 by Arjun Menon, Driss Guessous, Han Guo, Jack Zhang, Tri Dao, Vijay Thakkar, Yoon Kim.

Figure 1: Runtime breakdown for LLaMA-3-style 1B model training on a single H100 using TorchTitan.
LLM training has become just as much of a systems problem as a modeling one. FLOPs in modern Transformer-based LLMs are dominated by matrix multiplications and attention, whose kernels have been heavily optimized for Tensor Core execution. Yet Transformers, and deep learning architectures more broadly, also contain …
Figure 2: Forward pass of a standard Transformer layer. The top row shows the canonical formulation, …
Figure 3: A GEMM mainloop computes output tiles; the epilogue transforms each tile before the final global-memory store.
The epilogue is a natural place to implement fusions because the output of the matmul is already present on chip, close to the compute cores. Practical epilogues commonly perform scaling, bias addition, activations, residual updates, data type conversions, tile-wise reductions and other output elem…
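The mainloop/epilogue split in Figure 3 can be sketched on the CPU with NumPy. This is a hypothetical illustration, not CODA's interface. The names `tiled_gemm`, `add_residual`, `scale_cols` and `relu`, and the tile sizes, are introduced here.

- Each output tile is accumulated over K in a local FP32 array, which stands in for registers or shared memory.
- The epilogue steps are applied to that local tile in order.
- Only then is the tile stored once into the output array, which stands in for global memory.

```python
import numpy as np


def tiled_gemm(a, b, epilogue=(), bm=32, bn=32, bk=32):
    """Compute epilogue(a @ b) one (bm x bn) output tile at a time.

    Each epilogue step is a callable f(tile, rows, cols) -> tile that sees only
    the local tile plus the row/column slices it covers.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    out = np.empty((m, n), dtype=np.float32)
    for i in range(0, m, bm):
        for j in range(0, n, bn):
            rows, cols = slice(i, min(i + bm, m)), slice(j, min(j + bn, n))
            # "Mainloop": accumulate this tile over K in a local FP32 buffer.
            acc = np.zeros((rows.stop - rows.start, cols.stop - cols.start), dtype=np.float32)
            for p in range(0, k, bk):
                ks = slice(p, min(p + bk, k))
                acc += a[rows, ks] @ b[ks, cols]
            # "Epilogue": transform the tile while it is still local.
            for step in epilogue:
                acc = step(acc, rows, cols)
            out[rows, cols] = acc  # the single store for this tile
    return out


def add_residual(res):
    # Accumulation-style step: add the matching tile of a residual tensor.
    return lambda t, rows, cols: t + res[rows, cols]


def scale_cols(w):
    # Scaling step: multiply each output column by a per-column (per-N) factor.
    return lambda t, rows, cols: t * w[cols]


def relu(t, rows, cols):
    # Elementwise activation applied to the local tile.
    return np.maximum(t, 0.0)
```

Usage: `tiled_gemm(a, b, [add_residual(r), relu, scale_cols(w)])` matches `np.maximum(a @ b + r, 0) * w` up to FP32 rounding. The residual, activation and scaling never touch the output array before the per-tile store. Reordering the list changes the result, which is the sense in which the steps compose.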
Figure 4: GEMM-RMSNorm-GEMM reparameterization.
We address the reduction by splitting it into two levels. The first GEMM epilogue computes tile-local partial reductions, and a small auxiliary kernel reduces these partials across tiles to obtain r. Since the auxiliary kernel reads a few partial values per tile rather than the full activation tensor, its memory traffic is much smaller than that of a standalone RMSNo…
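The traffic argument in Figure 4 can be made concrete with rough arithmetic. The function below and its inputs are illustrative assumptions:

- BF16 activations and FP32 partials;
- a 128-column output tile;
- M = 8192 tokens and N = 4096 features.

The estimate counts only the row-reduction traffic. It ignores the weight vector and the store of the partials from the first GEMM epilogue, which at most doubles the auxiliary figure.

```python
def rmsnorm_traffic_bytes(m, n, tile_n, act_bytes=2, partial_bytes=4):
    """Rough global-memory traffic in bytes for the RMSNorm row reduction.

    standalone: a separate RMSNorm kernel reads h (M x N) and writes the
                normalized tensor (M x N), both at act_bytes per element.
    auxiliary:  the cross-tile kernel reads M x ceil(N / tile_n) tile
                partials and writes M row factors r, at partial_bytes each.
    """
    n_tiles = -(-n // tile_n)  # ceiling division
    standalone = 2 * m * n * act_bytes
    auxiliary = (m * n_tiles + m) * partial_bytes
    return standalone, auxiliary


standalone, auxiliary = rmsnorm_traffic_bytes(8192, 4096, 128)
print(standalone, auxiliary, round(standalone / auxiliary, 1))
```

With these numbers the standalone kernel moves 134,217,728 bytes and the auxiliary reduction 1,081,344 bytes, roughly 124x less. The ratio grows with the tile width, since each row contributes N / tile_n partials instead of N elements.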
Figure 5: Benchmarks.
Here $r = 1/\sqrt{\operatorname{reduce}(r_b) + \epsilon}$ is computed by a small auxiliary reduction over the tile partials $r_b$. This decomposition replaces a standalone RMSNorm kernel with tile-local epilogue work around the two GEMMs, plus a lightweight auxiliary reduction. In …
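A NumPy sketch of the GEMM-RMSNorm-GEMM decomposition described in Figures 4 and 5. It follows the steps named in the paper's appendix kernel example:

- GEMM 1 epilogue: a residual add D = acc + C, a per-tile mean of squares on the unscaled D, and a side output O = D * W, where W is the RMSNorm weight (written `gamma` here);
- an auxiliary reduction to one factor r per row;
- GEMM 2 epilogue: r applied to each row of the output.

The function names, the tile width `bn`, equal-width tiles, and tiling only along N are simplifications assumed here. The sketch checks algebraic equivalence in float64, not performance.

```python
import numpy as np


def reference(x, w1, c, gamma, w2, eps=1e-6):
    # Unfused path: GEMM -> residual add -> standalone RMSNorm -> GEMM.
    h = x @ w1 + c
    y = h / np.sqrt(np.mean(h * h, axis=1, keepdims=True) + eps) * gamma
    return y @ w2


def reparameterized(x, w1, c, gamma, w2, eps=1e-6, bn=64):
    m, n = x.shape[0], w1.shape[1]
    assert n % bn == 0, "sketch assumes equal-width column tiles"
    o = np.empty((m, n))        # side output D * gamma, consumed by GEMM 2
    s = np.empty((m, n // bn))  # tile partials: mean of squares per column tile
    for nb in range(n // bn):
        cols = slice(nb * bn, (nb + 1) * bn)
        # GEMM 1 epilogue on one column tile (all rows at once for brevity).
        d = x @ w1[:, cols] + c[:, cols]   # residual: D = acc + C
        s[:, nb] = np.mean(d * d, axis=1)  # partial mean of squares on unscaled D
        o[:, cols] = d * gamma[cols]       # O = D * gamma (unscaled D would also be stored)
    # Auxiliary kernel: equal tile widths, so the mean of tile means is the row mean.
    r = 1.0 / np.sqrt(np.mean(s, axis=1) + eps)
    # GEMM 2 with the row factor r applied in its epilogue.
    return (o @ w2) * r[:, None]
```

The only cross-tile dependency is the M x (N / bn) array `s`. That array is what lets the standalone RMSNorm kernel disappear.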
Figure 6: Relative error.
Numerics. The reparameterization changes where the RMSNorm scale is applied: the row-wise factor r is delayed from before the second GEMM to the second GEMM epilogue. We compare BF16 GEMM-RMSNorm-GEMM outputs against an FP32 reference on Llama-3 8B layers. We report the errors of CODA and QuACK, on which our GEMM template is based, normalized by the error of the standard PyTorch path …
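A short derivation of why delaying r is exact in real arithmetic and changes only rounding. Here h (M x N) denotes the first GEMM output after the residual add and γ the RMSNorm weight; these symbol choices are assumptions of this note.

$$
\begin{aligned}
\operatorname{RMSNorm}(h) &= \operatorname{diag}(r)\, h \operatorname{diag}(\gamma),
\qquad r_m = \Big(\tfrac{1}{N}\textstyle\sum_{n} h_{mn}^2 + \epsilon\Big)^{-1/2},\\
\operatorname{RMSNorm}(h)\, W_2 &= \operatorname{diag}(r)\,\big(h \operatorname{diag}(\gamma)\, W_2\big).
\end{aligned}
$$

The second line is associativity. Left-multiplication by diag(r) scales rows and commutes past the right-multiplication by W_2, so the second GEMM can consume h diag(γ) and scale each row of its output tile by r_m.

In BF16 the two orderings round different intermediates:

- The standard path rounds the normalized y before the GEMM.
- The reparameterized path rounds h diag(γ) and applies r in the epilogue, where the accumulator is typically held in FP32.

This is why the comparison is made against an FP32 reference rather than between the two BF16 paths.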
Figure 7: Pairwise activations operate on local feature pairs in the GEMM epilogue.
This form captures several operations in Transformer blocks:
  • RoPE rotates each feature pair and returns two outputs;
  • SwiGLU combines the gate and value streams into one output;
  • the SwiGLU backward pass maps one incoming gradient into gradients for both paired inputs.
Pairwise activations couple neighboring feature lanes and may change …
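A NumPy sketch of the three pairwise cases listed in Figure 7, applied to an epilogue tile whose feature pairs sit in adjacent columns (2i, 2i+1).

Several details are assumptions for illustration:

- the interleaved layout (kernels may pair columns differently, for example the half-split RoPE layout);
- the function names;
- the RoPE convention, with angle position * base**(-2i / d) and base = 10000.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def swiglu_fwd(tile):
    # Pair (gate, value) in columns (2i, 2i+1) -> one output column per pair.
    gate, value = tile[:, 0::2], tile[:, 1::2]
    return gate * sigmoid(gate) * value


def swiglu_bwd(tile, dout):
    # One incoming gradient per pair -> gradients for both paired inputs.
    gate, value = tile[:, 0::2], tile[:, 1::2]
    sg = sigmoid(gate)
    dgrad = np.empty_like(tile)
    dgrad[:, 0::2] = dout * value * sg * (1.0 + gate * (1.0 - sg))  # d/d gate of silu(gate) * value
    dgrad[:, 1::2] = dout * gate * sg                               # d/d value = silu(gate)
    return dgrad


def rope(tile, positions, base=10000.0):
    # Rotate each pair (x0, x1) by angle position * base**(-2i / d): two outputs per pair.
    d = tile.shape[1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    theta = positions[:, None] * inv_freq[None, :]
    cos, sin = np.cos(theta), np.sin(theta)
    x0, x1 = tile[:, 0::2], tile[:, 1::2]
    out = np.empty_like(tile)
    out[:, 0::2] = x0 * cos - x1 * sin
    out[:, 1::2] = x0 * sin + x1 * cos
    return out
```

Each function reads only its own feature pair plus, for RoPE, the row's position. That locality is what lets these operations run on a tile without leaving the epilogue. SwiGLU halves the output width, which is the sense in which pairwise activations "may change" the tile shape.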
Figure 8: Kernel-level speedups for representative GEMM-plus-epilogue primitives across …

Figure 9: Forward and backward fusion for GEMM–epilogue blocks. Forward epilogues attach to the …

Figure 10: Kernel-level speedups on reparameterized Transformer kernels relative to cuBLAS with …

Figure 11: Block-level speedups for reparameterized Transformer kernel sequences, including …
Original abstract

Transformer training systems are built around dense linear algebra, yet a nontrivial fraction of end-to-end time is spent on surrounding memory-bound operators. Normalization, activations, residual updates, reductions, and related computations repeatedly move large intermediate tensors through global memory while performing little arithmetic, making data movement an increasingly important bottleneck in otherwise highly optimized training stacks. We introduce CODA, a GPU kernel abstraction that expresses these computations as GEMM-plus-epilogue programs. CODA is based on the observation that many Transformer operators exposed as separate framework kernels can be algebraically reparameterized to execute while a GEMM output tile remains on chip, before it is written to memory. The abstraction fixes the GEMM mainloop and exposes a small set of composable epilogue primitives for scaling, reductions, pairwise transformations, and accumulation. This constrained interface preserves the performance structure of expert-written GEMMs while remaining expressive enough to cover nearly all non-attention computation in the forward and backward pass of a standard Transformer block. Across representative Transformer workloads, both human- and LLM-authored CODA kernels achieve high performance, suggesting that GEMM-plus-epilogue programming offers a practical path toward combining framework-level productivity with hardware-level efficiency.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, and this is the friction.

Referee Report

2 major / 2 minor

Summary. The paper introduces CODA, a GPU kernel abstraction for expressing non-attention computations (normalization, activations, residuals, reductions) in Transformer forward and backward passes as GEMM-plus-epilogue programs. It rests on algebraic reparameterization to execute these operators while a GEMM output tile remains on-chip, using a fixed, composable set of epilogue primitives for scaling, reductions, pairwise transforms, and accumulation. The central claim is that this constrained interface covers nearly all non-attention work in a standard Transformer block while preserving the performance structure of expert-written GEMMs; both human- and LLM-authored kernels are reported to achieve high performance.

Significance. If the expressiveness and performance-preservation claims hold, CODA would offer a structured route to reduce memory-bandwidth bottlenecks in Transformer training by enabling on-chip fusion of memory-bound operators around highly optimized GEMM mainloops. This could improve end-to-end efficiency without requiring full custom kernel rewrites, bridging framework productivity and hardware-level performance. The reparameterization insight and constrained primitive set are potentially reusable beyond the specific Transformer setting.

major comments (2)
  1. [§3] §3 (Epilogue interface and reparameterization): The load-bearing claim that the constrained epilogue primitives (scaling, reductions, pairwise transforms, accumulation) suffice for nearly all non-attention operators—including backward-pass reductions such as LayerNorm gradients or fused residual-gradient paths—without forcing off-chip writes or GEMM-tile eviction is not yet substantiated. No explicit mapping or composition example is given for a reduction-heavy backward operator that would confirm the output tile can stay resident while the required reductions complete.
  2. [Abstract and §5] Abstract and §5 (Performance evaluation): The assertion that both human- and LLM-authored CODA kernels achieve high performance and preserve GEMM performance structure lacks any reported benchmarks, baselines, error bars, or implementation details. Without quantitative evidence that the epilogue overhead does not degrade the underlying GEMM throughput, the performance-preservation part of the central claim cannot be assessed.
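Comment 1 asks for a reduction-heavy backward example. The sketch below is a hypothetical illustration, not the paper's construction. It assumes the LayerNorm output y feeds a GEMM z = y @ w, so the incoming gradient dy = dz @ w.T is itself produced by a GEMM whose epilogue can form per-tile partial sums.

- The function names and tile width `bn` are introduced here.
- The sketch still stores g = dy * gamma before the final elementwise pass, because dx needs full-row statistics.
- It therefore shows how the reduction splits into tile partials plus a small auxiliary reduction, not that the tile stays resident throughout. That residency is what the referee asks the authors to demonstrate.

```python
import numpy as np


def layernorm_fwd(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=1, keepdims=True)
    rstd = 1.0 / np.sqrt(x.var(axis=1, keepdims=True) + eps)
    xhat = (x - mu) * rstd
    return xhat * gamma + beta, xhat, rstd


def layernorm_bwd_tiled(dz, w, xhat, rstd, gamma, bn=32):
    """dx for LayerNorm when dy = dz @ w.T comes out of a GEMM, tile by tile."""
    m, n = xhat.shape
    assert n % bn == 0, "sketch assumes equal-width column tiles"
    g = np.empty((m, n))
    part_g = np.empty((m, n // bn))
    part_gx = np.empty((m, n // bn))
    for t in range(n // bn):
        cols = slice(t * bn, (t + 1) * bn)
        dy = dz @ w[cols, :].T                            # one column tile of dy
        gt = dy * gamma[cols]                             # epilogue: scale by the LayerNorm weight
        g[:, cols] = gt                                   # stored: dx needs row-wide statistics
        part_g[:, t] = gt.sum(axis=1)                     # tile partial of sum(g)
        part_gx[:, t] = (gt * xhat[:, cols]).sum(axis=1)  # tile partial of sum(g * xhat)
    # Auxiliary reduction over tiles: two scalars per row.
    mean_g = part_g.sum(axis=1, keepdims=True) / n
    mean_gx = part_gx.sum(axis=1, keepdims=True) / n
    # Final elementwise pass: dx = rstd * (g - mean(g) - xhat * mean(g * xhat)).
    return rstd * (g - mean_g - xhat * mean_gx)
```

The per-row statistics reduce to two scalars, mean(g) and mean(g * xhat). This mirrors the single factor r in the forward RMSNorm case.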
minor comments (2)
  1. [Abstract] The abstract would be clearer if it listed the exact epilogue primitives and gave one concrete reparameterization example (e.g., for a residual-add or ReLU gradient) to illustrate the interface.
  2. [§2] Notation for the GEMM mainloop versus epilogue boundary should be introduced earlier and used consistently when describing on-chip residency.

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for their constructive and detailed review. The comments identify key areas where additional substantiation would strengthen the manuscript. We address each major comment below and describe the revisions we will make.

Point-by-point responses
  1. Referee: [§3] §3 (Epilogue interface and reparameterization): The load-bearing claim that the constrained epilogue primitives (scaling, reductions, pairwise transforms, accumulation) suffice for nearly all non-attention operators—including backward-pass reductions such as LayerNorm gradients or fused residual-gradient paths—without forcing off-chip writes or GEMM-tile eviction is not yet substantiated. No explicit mapping or composition example is given for a reduction-heavy backward operator that would confirm the output tile can stay resident while the required reductions complete.

    Authors: We appreciate the referee's emphasis on this point. Section 3 presents the epilogue primitives and argues for their composability based on algebraic reparameterization, but we agree that an explicit worked example for a backward-pass reduction operator is needed to fully substantiate the claim. In the revised manuscript we will add a detailed composition for the LayerNorm backward pass (including fused residual-gradient paths), showing step-by-step how the required reductions and accumulations are performed while the GEMM output tile remains on-chip. This addition will directly address the concern about off-chip writes or tile eviction. revision: yes

  2. Referee: [Abstract and §5] Abstract and §5 (Performance evaluation): The assertion that both human- and LLM-authored CODA kernels achieve high performance and preserve GEMM performance structure lacks any reported benchmarks, baselines, error bars, or implementation details. Without quantitative evidence that the epilogue overhead does not degrade the underlying GEMM throughput, the performance-preservation part of the central claim cannot be assessed.

    Authors: We acknowledge that the current manuscript states the performance outcome at a high level without the quantitative details required for rigorous evaluation. We will expand Section 5 to include comprehensive benchmark results for both human- and LLM-authored CODA kernels. These will report absolute and relative throughput numbers against standard cuBLAS/cuDNN baselines, include error bars from repeated runs, and provide implementation details (hardware platform, compiler flags, and measurement methodology). The new data will allow direct assessment of whether epilogue overhead preserves the performance structure of the underlying GEMM mainloop. revision: yes

Circularity Check

0 steps flagged

No significant circularity in CODA derivation

full rationale

The paper introduces CODA as a new abstraction based on an explicit observation that many Transformer operators can be algebraically reparameterized to run as epilogues while GEMM output tiles remain on-chip. The central claim of expressiveness for nearly all non-attention forward and backward work is presented as a property of the fixed mainloop plus composable primitives (scaling, reductions, pairwise transforms, accumulation), which is then validated through kernel implementations and benchmarks rather than being defined into existence or reduced to fitted inputs. No equations, uniqueness theorems, or load-bearing steps are shown to collapse by construction to prior self-citations or tautological redefinitions. The derivation chain is therefore self-contained as an engineering interface choice supported by empirical coverage demonstrations.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axioms · 0 invented entities

The paper introduces a systems abstraction rather than new mathematical entities or fitted constants; it relies on standard GPU memory-hierarchy assumptions and the stated observation about operator reparameterization.

axioms (1)
  • domain assumption: Many Transformer operators can be algebraically reparameterized to execute while a GEMM output tile remains on chip.
    This is the central observation stated in the abstract that enables the epilogue approach.

pith-pipeline@v0.9.0 · 5755 in / 1055 out tokens · 48376 ms · 2026-05-20T07:15:40.235743+00:00 · methodology


Reference graph

Works this paper leans on

53 extracted references · 53 canonical work pages · 3 internal anchors

  [1] J. Ansel, E. Yang, H. He, N. Gimelshein, A. Jain, M. Voznesensky, B. Bao, P. Bell, D. Berard, E. Burovski, et al. PyTorch 2: Faster machine learning through dynamic Python bytecode transformation and graph compilation. In Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2, …

  [2] T. Chen, T. Moreau, Z. Jiang, L. Zheng, E. Yan, H. Shen, M. Cowan, L. Wang, Y. Hu, L. Ceze, et al. TVM: An automated end-to-end optimizing compiler for deep learning. In 13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18), pages 578–594, 2018.

  [3] W. Chen, J. Zhu, Q. Fan, Y. Ma, and A. Zou. CUDA-LLM: LLMs can write efficient CUDA kernels. arXiv preprint arXiv:2506.09092, 2025.

  [4] Z. Chen, A. Kerr, R. Cai, J. Kosaian, H. Wu, Y. Ding, and Y. Xie. EVT: Accelerating deep learning training with epilogue visitor tree. In Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 3, pages 301–316, 2024.

  [5] A. Grattafiori, A. Dubey, A. Jauhri, A. Pandey, A. Kadian, A. Al-Dahle, A. Letman, A. Mathur, A. Schelten, A. Vaughan, et al. The Llama 3 herd of models. arXiv preprint arXiv:2407.21783, 2024.

  [6] P.-L. Hsu, Y. Dai, V. Kothapalli, Q. Song, S. Tang, S. Zhu, S. Shimizu, S. Sahni, H. Ning, and Y. Chen. Liger kernel: Efficient Triton kernels for LLM training. arXiv preprint arXiv:2410.10989, 2024.

  [7] A. Ivanov, N. Dryden, T. Ben-Nun, S. Li, and T. Hoefler. Data movement is all you need: A case study on optimizing transformers. Proceedings of Machine Learning and Systems, 3:711–732, 2021.

  [8] Z. Jia, O. Padon, J. Thomas, T. Warszawski, M. Zaharia, and A. Aiken. TASO: Optimizing deep learning computation with automatic generation of graph substitutions. In Proceedings of the 27th ACM Symposium on Operating Systems Principles, pages 47–62, 2019.

  [9] W. Kwon, Z. Li, S. Zhuang, Y. Sheng, L. Zheng, C. H. Yu, J. Gonzalez, H. Zhang, and I. Stoica. Efficient memory management for large language model serving with PagedAttention. In Proceedings of the 29th Symposium on Operating Systems Principles, pages 611–626, 2023.

  [10] R. T. Lange, Q. Sun, A. Prasad, M. Faldor, Y. Tang, and D. Ha. Towards robust agentic CUDA kernel benchmarking, verification, and optimization. arXiv preprint arXiv:2509.14279, 2025.

  [11] W. Liang, T. Liu, L. Wright, W. Constable, A. Gu, C.-C. Huang, I. Zhang, W. Feng, H. Huang, J. Wang, et al. TorchTitan: One-stop PyTorch native solution for production ready LLM pre-training. arXiv preprint arXiv:2410.06511, 2024.

  [12] A. Ouyang, S. Guo, S. Arora, A. L. Zhang, W. Hu, C. Ré, and A. Mirhoseini. KernelBench: Can LLMs write efficient GPU kernels? arXiv preprint arXiv:2502.10517, 2025.

  [13] B. Spector, J. Juravsky, S. Sul, O. Dugan, D. Lim, D. Fu, S. Arora, and C. Ré. Look ma, no bubbles! Designing a low-latency megakernel for Llama-1B, 2025.

  [14] B. F. Spector, S. Arora, A. Singhal, D. Y. Fu, and C. Ré. ThunderKittens: Simple, fast, and adorable AI kernels. arXiv preprint arXiv:2410.20399, 2024.

  [15] J. Su, M. Ahmed, Y. Lu, S. Pan, W. Bo, and Y. Liu. RoFormer: Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024.

  [16] S. Su, X. Sun, X. Li, A. Wang, J. Li, and C. Shum. CUDA-L2: Surpassing cuBLAS performance for matrix multiplication through reinforcement learning. arXiv preprint arXiv:2512.02551, 2025.

  [17] S. H. Sul, S. Arora, B. F. Spector, and C. Ré. ParallelKittens: Systematic and practical simplification of multi-GPU AI kernels. arXiv preprint arXiv:2511.13940, 2025.

  [18] V. Thakkar, P. Ramani, C. Cecka, A. Shivam, H. Lu, E. Yan, J. Kosaian, M. Hoemmen, H. Wu, A. Kerr, M. Nicely, D. Merrill, D. Blasig, A. Atluri, F. Qiao, P. Majcher, P. Springer, M. Hohnerbach, J. Wang, and M. Gupta. CUTLASS, Jan. 2023. URL https://github.com/NVIDIA/cutlass.

  [19] P. Tillet, H.-T. Kung, and D. Cox. Triton: An intermediate language and compiler for tiled neural network computations. In Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages, pages 10–19, 2019.

  [20] L. Wang, Y. Cheng, Y. Shi, Z. Tang, Z. Mo, W. Xie, L. Ma, Y. Xia, J. Xue, F. Yang, et al. TileLang: A composable tiled programming model for AI systems. arXiv preprint arXiv:2504.17577, 2025.

  [21] E. Wijmans, B. Huval, A. Hertzberg, V. Koltun, and P. Krähenbühl. Cut your losses in large-vocabulary language models. In International Conference on Learning Representations, 2025.

  [22] M. Wu, X. Cheng, S. Liu, C. Shi, J. Ji, M. K. Ao, P. Velliengiri, X. Miao, O. Padon, and Z. Jia. Mirage: A multi-level superoptimizer for tensor programs. In 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25), pages 21–38, 2025.

  [23] Z. Ye, L. Chen, R. Lai, W. Lin, Y. Zhang, S. Wang, T. Chen, B. Kasikci, V. Grover, A. Krishnamurthy, et al. FlashInfer: Efficient and customizable attention engine for LLM inference serving. Proceedings of Machine Learning and Systems, 7, 2025.

  [24] M. Yuksekgonul, D. Koceja, X. Li, F. Bianchi, J. McCaleb, X. Wang, J. Kautz, Y. Choi, J. Zou, C. Guestrin, et al. Learning to discover at test time. arXiv preprint arXiv:2601.16175, 2026.

  [25] L. Zheng, L. Yin, Z. Xie, C. L. Sun, J. Huang, C. H. Yu, S. Cao, C. Kozyrakis, I. Stoica, J. E. Gonzalez, et al. SGLang: Efficient execution of structured language model programs. Advances in Neural Information Processing Systems, 37:62557–62583, 2024.

  [26] FlashInfer 0.6.10.post1.

  [27] QuACK Kernels 0.4.1.

Appendix excerpts

  • A Backward Pass, A.1 Tile-wise Epilogue: Partition the GEMM output h into tiles h[i,j]. A tile-wise epil…
  • Listing 2: Kernel Example. A CuTe-DSL epilogue for a GEMM with residual, partial mean-of-squares, and fused per-N RMSNorm-weight scaling, composed as an EVTList of three epilogue-visitor-tree nodes:
      EVTResidual: D = acc + C
      EVTColBlockReductionStore: S[m, nb] = mean(D[m, nb*bs:(nb+1)*bs]^2), built with inv_block_size = 1.0 / tile_shape_mnk[1]
      EVTRowVecMulPostAct (local): O[m, n] = D[m, n] * W[n], side output via TMA
    EVTRowVecMulPostAct loads a per-N row vector W (cp.async to smem, then s2r), multiplies the accumulator by W into a separate register tile, and stores that scaled tile to a side output mPo… The partial sum-of-squares is computed on the unscaled D, so a downstream rstd reduction sees the GEMM output before W is applied; tRS_rD is preserved, so the main D output is also unscaled. The preparation function returns epi_cls, epi_args, epi_outs, and epi_keys, with epi_keys = (C.dtype, S.dtype, W.dtype, O.dtype, EVTResidual, EVTColBlockReductionStore, EVTRowVecMulPostAct).
  • C Experiments, C.1 List of Kernels: We …