Is flash attention stable?
5 Pith papers cite this work.
citing papers explorer
- Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
  Low-precision Flash Attention fails due to similar low-rank attention representations combined with biased rounding errors that accumulate and corrupt weight updates; a minimal fix to reduce rounding bias stabilizes training.
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  FlashAttention-3 achieves 1.5-2x speedup on H100 GPUs for attention, reaching 740 TFLOPs/s (75% utilization) in FP16 and near 1.2 PFLOPs/s in FP8 while cutting numerical error by 2.6x versus baseline FP8 attention.
- Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression
  Gated KalmaNet uses exact Kalman gain computation with adaptive gating and Chebyshev iteration to improve SSM performance on long-context tasks over prior approximations like DeltaNet.
- PRISM: Probabilistic Runtime Insights and Scalable Performance Modeling for Large-Scale Distributed Training
  PRISM introduces a probabilistic performance modeling framework that quantifies guarantees on training time for large-scale distributed systems under runtime variability.
- A Survey on Efficient Inference for Large Language Models
  The paper surveys techniques to speed up and reduce the resource needs of LLM inference, organized by data-level, model-level, and system-level changes, with comparative experiments on representative methods.