Training Infinitely Deep and Wide Transformers
Pith reviewed 2026-05-19 22:08 UTC · model grok-4.3
The pith
Gradient flow in the conditional Wasserstein metric converges to global minima for infinitely deep and wide transformers when the initial loss is small enough and the attention NTK is injective.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
Under the Neural Tangent Kernel injectivity assumption—equivalent to linear independence of log-sum-exp functions modulo affine functions—gradient flow in the conditional Wasserstein metric converges to global minima of the training risk for mean-field transformers whenever the initial loss is sufficiently small.
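A finite-dimensional, linearized toy can show why an injective tangent kernel drives the risk to a global minimum. It is not the paper's attention NTK.

- A hypothetical Jacobian J maps p parameters to predictions on n samples.
- The Gram matrix K = J J^T plays the role of the tangent kernel, and injectivity corresponds to lambda_min(K) > 0.
- Along the gradient flow of the risk, the residual obeys r' = -K r. The loss therefore satisfies L' <= -2 lambda_min L and decays at least like exp(-2 lambda_min t).

The sketch checks the discrete analogue of that bound for small-step gradient descent. In nonlinear models, the small-initial-loss hypothesis typically serves to keep the trajectory where the kernel stays non-degenerate. The linear toy sidesteps that issue.

```python
import numpy as np

def linearized_risk_descent(J, y, theta0, lr=0.1, steps=500):
    """Small-step gradient descent on the linearized risk 0.5*||J theta - y||^2.

    J is a hypothetical Jacobian; K = J J^T is its tangent-kernel Gram matrix.
    Returns the loss recorded before every step.
    """
    theta = theta0.copy()
    losses = np.empty(steps)
    for k in range(steps):
        r = J @ theta - y            # residuals on the n training samples
        losses[k] = 0.5 * r @ r
        theta -= lr * (J.T @ r)      # gradient step; residual evolves as r <- (I - lr K) r
    return losses

rng = np.random.default_rng(0)
n, p = 5, 50                         # n samples, p >> n parameters
J = rng.normal(size=(n, p)) / np.sqrt(p)
K = J @ J.T                          # tangent-kernel Gram matrix (n x n)
lam_min = np.linalg.eigvalsh(K)[0]   # eigvalsh sorts ascending; > 0 iff K is injective
losses = linearized_risk_descent(J, rng.normal(size=n), np.zeros(p))
# Discrete analogue of L(t) <= L(0) exp(-2 lam_min t), valid when lr * lam_max(K) <= 1:
bound = losses[0] * (1 - 0.1 * lam_min) ** (2 * np.arange(len(losses)))
```

If n exceeds p, K becomes singular. The loss can then plateau above zero whenever y lies outside the range of J, which is the finite analogue of a non-injective kernel.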
What carries the argument
The conditional Wasserstein gradient flow of the training risk, derived from adjoint sensitivity analysis on the forward and backward ODEs that govern the evolution of token distributions and attention parameters.
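The adjoint-sensitivity step can be illustrated on the neural-ODE (ResNet) analogue mentioned in the abstract, not on the paper's coupled attention PDE. The sketch makes four assumptions:

- a toy depth-continuous model dx/ds = tanh(W x) on s in [0, 1];
- a single weight matrix W shared across depth;
- an explicit Euler discretization;
- the loss 0.5 * ||x(1) - target||^2.

The forward pass stores the states. The backward pass integrates the adjoint p from p(1) = x(1) - target down to s = 0 and accumulates the parameter sensitivity along the way. This yields the exact gradient of the discretized loss. It is not the paper's conditional Wasserstein gradient, which must also account for tokens interacting through attention.

```python
import numpy as np

def ode_forward(x0, W, n_steps):
    """Explicit Euler for the toy ODE dx/ds = tanh(W x) on [0, 1]; returns all states."""
    h = 1.0 / n_steps
    xs = [x0]
    for _ in range(n_steps):
        xs.append(xs[-1] + h * np.tanh(W @ xs[-1]))
    return xs

def ode_loss(x0, W, target, n_steps):
    """Terminal loss 0.5 * ||x(1) - target||^2."""
    xT = ode_forward(x0, W, n_steps)[-1]
    return 0.5 * np.sum((xT - target) ** 2)

def adjoint_gradient(x0, W, target, n_steps):
    """Gradient of ode_loss w.r.t. W via a backward (adjoint) pass through depth."""
    h = 1.0 / n_steps
    xs = ode_forward(x0, W, n_steps)
    p = xs[-1] - target                     # terminal condition p(1) = dLoss/dx(1)
    grad = np.zeros_like(W)
    for k in range(n_steps - 1, -1, -1):
        D = 1.0 - np.tanh(W @ xs[k]) ** 2   # tanh'(W x_k), elementwise
        grad += h * np.outer(D * p, xs[k])  # sensitivity of step k to W, weighted by p_{k+1}
        p = p + h * W.T @ (D * p)           # adjoint step: p_k = (I + h diag(D) W)^T p_{k+1}
    return grad
```

A finite-difference check validates such a backward pass. Perturb one entry of W by plus or minus a small epsilon, then take the central difference of `ode_loss`; the result should match the corresponding entry of `adjoint_gradient`.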
If this is right
- The optimization landscape of the infinite transformer contains no spurious local minima once the NTK injectivity condition holds.
- Gradient-based training reaches a global minimum from any sufficiently low initial loss value.
- Token distributions evolve according to unique flow maps that satisfy well-posed ODEs in an appropriate function space.
- The injectivity condition is satisfied by discrete distributions, uniform distributions, and Gaussian mixtures.
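The forward pass behind those flow maps can be sketched with particles. The sketch assumes the following:

- plain softmax attention heads with parameters theta = (Q, K, V);
- a finite token cloud standing in for mu;
- depth s in [0, 1] discretized into L explicit Euler steps;
- H heads per layer whose velocities are averaged (a mean-field width scaling).

MLP blocks, normalization, and the coupling of several input distributions are left out. The names `attention_field` and `transformer_forward` and the array layout are illustrative, not the paper's.

```python
import numpy as np

def attention_field(X, Q, K, V):
    """Softmax-attention velocity for every token.

    X: (n, d) token positions (an empirical stand-in for mu).
    Returns an (n, d) array whose row i is sum_j softmax_j(<Q x_i, K x_j>) V x_j.
    """
    scores = (X @ Q.T) @ (X @ K.T).T                      # (n, n) entries <Q x_i, K x_j>
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    P = np.exp(scores)
    P /= P.sum(axis=1, keepdims=True)                     # row-stochastic attention matrix
    return P @ (X @ V.T)

def transformer_forward(X0, heads):
    """Euler scheme for dx/ds = (mean over heads at depth s) of attention_field.

    heads: (L, H, 3, d, d) array; heads[k, m] = (Q, K, V) of head m at layer k.
    """
    n_layers, n_heads = heads.shape[:2]
    h = 1.0 / n_layers                                    # depth step on s in [0, 1]
    X = X0.copy()
    for k in range(n_layers):
        v = np.mean([attention_field(X, *heads[k, m]) for m in range(n_heads)], axis=0)
        X = X + h * v                                     # all tokens move simultaneously
    return X
```

Each token's velocity depends on the whole cloud X, so the tokens of one input evolve jointly. This is the coupling that turns the neural ODE of a ResNet into the neural PDE described in the abstract.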
Where Pith is reading between the lines
- Large finite transformers may inherit approximate global convergence when their width and depth are big enough for the mean-field description to be accurate.
- The linear-independence condition on log-sum-exp functions could be checked numerically on real data sets to predict whether global convergence is likely.
- The same conditional-Wasserstein-flow approach may be adaptable to other attention-based or coupling-based architectures.
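A minimal numerical version of that check follows. It rests on a hypothetical choice of function family, because the page does not spell out the paper's exact one. Each discrete token distribution mu_i = sum_j w_ij delta_{y_ij} is mapped to f_i(a) = log sum_j w_ij exp(<a, y_ij>), with a in R^d.

The f_i are linearly independent modulo affine functions exactly when f_1, ..., f_n together with 1, a_1, ..., a_d are linearly independent. It therefore suffices to evaluate all of them at random points and test the rank of the resulting matrix.

```python
import numpy as np

def lse_values(A, points, weights):
    """Evaluate f(a) = log sum_j w_j exp(<a, y_j>) at each row a of A (stable form)."""
    Z = A @ points.T                          # (K, m) inner products <a_k, y_j>
    zmax = Z.max(axis=1, keepdims=True)       # shift to avoid overflow
    return zmax[:, 0] + np.log(np.exp(Z - zmax) @ weights)

def independent_mod_affine(dists, n_samples=200, tol=1e-8, seed=0):
    """Numerical test of linear independence modulo affine functions.

    dists: list of (points, weights) pairs, points of shape (m, d), weights of shape (m,).
    Returns True if the sampled matrix [f_1, ..., f_n, 1, a_1, ..., a_d] has full
    column rank, which certifies that no nontrivial combination of the f_i is affine.
    """
    d = dists[0][0].shape[1]
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(n_samples, d))       # random evaluation points a
    cols = [lse_values(A, y, w) for y, w in dists]
    cols.append(np.ones(n_samples))           # constant part of an affine function
    cols.extend(A.T)                          # linear part: coordinates a_1, ..., a_d
    M = np.column_stack(cols)
    s = np.linalg.svd(M, compute_uv=False)    # singular values, in descending order
    return bool(np.sum(s > tol * s[0]) == M.shape[1])
```

The test is one-sided. Full column rank at the sampled points proves independence. A rank deficit is only evidence of dependence, since it could reflect poorly chosen sample points or the tolerance. The number of samples must also exceed n + d + 1.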
Load-bearing premise
The mean-field limit, together with the NTK injectivity condition, continues to describe the behavior of the large but finite transformers that are actually trained.
What would settle it
A concrete token distribution for which the associated log-sum-exp functions are linearly dependent modulo affine functions, together with a numerical check that the resulting gradient flow reaches a local minimum rather than a global one even from small initial loss.
Original abstract
Transformers have become the dominant architecture in modern machine learning, yet the theoretical understanding of their training dynamics remains limited. This paper develops a rigorous mathematical framework for analyzing gradient-based training of transformers in the mean-field regime, where both the depth (number of layers) and width (number of attention heads) tend to infinity. While ResNet training can be understood as controlling a neural ODE, transformer training corresponds to controlling a neural PDE, due to the coupling of multiple token distributions through the attention mechanism. Our mean-field model features two types of measure representations: token distributions evolving through layers and attention parameters at each layer. We establish well-posedness of the forward pass through infinitely deep transformers, characterizing token evolution via flow maps that satisfy ODEs in function spaces. Using adjoint sensitivity analysis, we derive an explicit formula for the conditional Wasserstein gradient of the training risk, involving adjoint variables governed by backward ODEs. We prove the existence and uniqueness of gradient flow curves in the conditional Wasserstein metric space, establishing a rigorous foundation for gradient-based transformer training. A key technical contribution is providing necessary and sufficient conditions for injectivity of the Neural Tangent Kernel (NTK) for attention mechanisms: we show that NTK injectivity is equivalent to linear independence of log-sum-exp functions modulo affine functions, a condition satisfied by diverse token distributions, including discrete distributions, uniform distributions, and Gaussian mixtures. Under this NTK injectivity assumption, we prove that gradient flow converges to global minima when the initial loss is sufficiently small, eliminating spurious local minima from the optimization landscape.
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The paper develops a mean-field analysis for training transformers with both depth and width tending to infinity. It models the system as a neural PDE coupling token distributions and attention parameters, proves well-posedness of the infinite-depth forward pass via flow maps and function-space ODEs, derives the training gradient explicitly using adjoint sensitivity in the conditional Wasserstein metric, establishes existence and uniqueness of gradient-flow trajectories, characterizes NTK injectivity as linear independence of log-sum-exp functions modulo affine maps (satisfied by discrete, uniform, and Gaussian-mixture token distributions), and proves that, under this injectivity assumption, gradient flow converges to global minima whenever the initial loss is sufficiently small.
Significance. If the derivations hold, the work supplies a rigorous PDE-theoretic foundation for transformer optimization that parallels neural-ODE analyses of ResNets while accounting for the attention-induced coupling of token measures. The explicit adjoint formula for the conditional Wasserstein gradient and the necessary-and-sufficient NTK-injectivity characterization are technically substantive contributions. The global-convergence result under a verifiable injectivity condition would be a notable advance for understanding the absence of spurious local minima in the infinite-width/depth regime.
major comments (2)
- [§5 (global convergence theorem)] The global-convergence statement (presumably Theorem 5.3 or the main result in §5) assumes NTK injectivity persists along the entire gradient-flow trajectory. The manuscript must verify or prove that the linear-independence condition on log-sum-exp functions is preserved (or at least not violated) by the evolving token distributions under the forward PDE; without this invariance the assumption cannot be maintained from a small initial loss to the global minimum.
- [§2–3 (mean-field limit and passage from finite to infinite)] No quantitative approximation rates or stability estimates are supplied that control the distance between the finite-width/depth transformer and its mean-field PDE limit. Because the central claim concerns practical gradient-based training, the absence of such bounds leaves open whether discretization or finite-head errors can reintroduce critical points before the loss becomes small (as noted in the stress-test concern).
minor comments (2)
- [§2] Notation for the conditional Wasserstein metric and the precise coupling between token measures and attention parameters should be introduced with a short self-contained paragraph early in §2 to aid readers unfamiliar with Wasserstein geometry on product spaces.
- [§3.1] The regularity assumptions (e.g., Lipschitz constants, moment bounds) required for well-posedness of the function-space ODEs and for the adjoint equations should be stated explicitly rather than left implicit in the existence proofs.
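The paper's appendix does state assumptions of the kind this comment asks to be made explicit. This page, however, captured them only as garbled extraction fragments. A best-effort reconstruction follows; the quantifiers and norms are read from those fragments and may not match the paper exactly.

- The first three conditions concern the token velocity field V_s[mu].
- The last three are the generic Carathéodory-type conditions on a time-dependent field f : I x X -> X. Together, these guarantee well-posedness of ODEs in a Banach space X.

```latex
% Token velocity field at depth s: V_s[\mu] : \mathbb{R}^d \to \mathbb{R}^d, for \mu \in \mathcal{P}_c(\mathbb{R}^d)
\begin{align*}
&\text{(growth)} && \exists\, C_0 \in L^1([0,1]):\ \|V_s[\mu]\|_{C^0} \le C_0(s)\,(1+R)
   \quad \forall s \in [0,1],\ R>0,\ \operatorname{supp}\mu \subset B(0,R),\\
&\text{(Lipschitz in } x) && \forall R>0\ \exists\, L_R \in L^1([0,1]):\
   \|V_s[\mu](x)-V_s[\mu](y)\| \le L_R(s)\,\|x-y\| \quad \forall x,y \in \mathbb{R}^d,\ \operatorname{supp}\mu \subset B(0,R),\\
&\text{(stability in } \mu) && \forall R>0\ \exists\, M_R \in L^1([0,1]):\
   \|V_s[\mu]-V_s[\nu]\|_{C^0(B(0,R))} \le M_R(s)\,W_1(\mu,\nu) \quad \operatorname{supp}\mu,\operatorname{supp}\nu \subset B(0,R),\\
&\text{(Carath\'eodory)} && t \mapsto f(t,x) \text{ measurable } \forall x \in \mathcal{X};\quad
   x \mapsto f(t,x) \text{ continuous for a.e. } t \in I,\\
&\text{(local } L^1\text{-Lipschitz)} && \forall \text{ bounded } U \subset \mathcal{X}\ \exists\, L_U \in L^1_{\mathrm{loc}}(I):\
   \|f(t,x)-f(t,y)\| \le L_U(t)\,\|x-y\| \quad \forall x,y \in U,\\
&(L^1\text{-linear growth}) && \exists\, C \in L^1_{\mathrm{loc}}(I):\ \|f(t,x)\| \le C(t)\,(1+\|x\|) \quad \forall x \in \mathcal{X}.
\end{align*}
```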
Simulated Author's Rebuttal
We thank the referee for the careful reading of our manuscript and the constructive comments. We address the major comments point by point below.
Point-by-point responses
- Referee: [§5 (global convergence theorem)] The global-convergence statement (presumably Theorem 5.3 or the main result in §5) assumes NTK injectivity persists along the entire gradient-flow trajectory. The manuscript must verify or prove that the linear-independence condition on log-sum-exp functions is preserved (or at least not violated) by the evolving token distributions under the forward PDE; without this invariance the assumption cannot be maintained from a small initial loss to the global minimum.
Authors: We thank the referee for highlighting this subtlety. The global convergence result in §5 is established under the assumption that the NTK injectivity condition (linear independence of log-sum-exp functions modulo affine maps) holds for the token distributions. In the revised manuscript, we will add a remark immediately after the statement of the theorem clarifying that this assumption is required to hold along the entire trajectory. We will also include a short continuity argument. The forward evolution is realized by smooth flow maps that act continuously on the space of measures in the weak topology, and the set of measures satisfying the independence condition is open. The property is therefore preserved whenever the initial loss is small enough for the trajectory to remain within that open set.
Revision: partial.
- Referee: [§2–3 (mean-field limit and passage from finite to infinite)] No quantitative approximation rates or stability estimates are supplied that control the distance between the finite-width/depth transformer and its mean-field PDE limit. Because the central claim concerns practical gradient-based training, the absence of such bounds leaves open whether discretization or finite-head errors can reintroduce critical points before the loss becomes small (as noted in the stress-test concern).
Authors: We agree that quantitative approximation rates between finite transformers and the mean-field PDE limit would strengthen the link to practical training. However, the present work concentrates on the rigorous analysis of the infinite-depth, infinite-width regime itself: well-posedness of the neural PDE, derivation of the conditional Wasserstein gradient via adjoint sensitivity, and global convergence under the NTK injectivity condition. Establishing explicit rates would require additional stability estimates for the coupled attention mechanism in the conditional Wasserstein metric and is a substantial undertaking that lies outside the current scope. In the revised manuscript we will insert a brief paragraph in the introduction and a corresponding note in the conclusion acknowledging this limitation and identifying the derivation of approximation rates as an important direction for future research. The core claims of the paper remain unaffected.
Revision: no.
Circularity Check
No circularity: derivations use standard adjoint and Wasserstein tools with independent NTK characterization
full rationale
The paper's core chain proceeds from well-posedness of the infinite-depth forward pass via flow maps satisfying ODEs in function spaces, through adjoint sensitivity to obtain the explicit conditional Wasserstein gradient, to existence/uniqueness of gradient-flow curves in that metric, and finally to global convergence under the NTK-injectivity assumption when initial loss is small. The injectivity result is stated as an independent necessary-and-sufficient characterization (linear independence of log-sum-exp functions modulo affine maps) verified on concrete distributions; it is not obtained by fitting inside the paper nor by self-referential definition. All steps rely on classical PDE/OT techniques whose validity does not presuppose the target convergence statement. No load-bearing self-citation, ansatz smuggling, or renaming of known empirical patterns occurs. The derivation is therefore self-contained.
Axiom & Free-Parameter Ledger
axioms (2)
- Standard math: Well-posedness of ODEs in appropriate function spaces for the token flow maps.
- Domain assumption: Existence and uniqueness of gradient flows in the conditional Wasserstein metric.
invented entities (1)
- Conditional Wasserstein metric on the space of attention parameters coupled to token distributions (no independent evidence).
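One common way to make the conditional Wasserstein metric concrete is assumed here; it is not quoted from the paper. Consider measures on [0, 1] x Theta whose depth marginal is uniform, and define W_c(rho, rho')^2 = integral over s in [0, 1] of W_2(rho_s, rho'_s)^2 ds. Here rho_s is the distribution of head parameters at depth s, and mass is transported only within a layer, never between layers.

The sketch evaluates this distance for L layers of H scalar head parameters each. It relies on a one-dimensional fact: the optimal coupling between two equal-size empirical measures pairs their sorted samples. Real heads theta = (Q, K, V) are matrices, so each layer would need a general OT solver instead of sorting.

```python
import numpy as np

def conditional_w2(theta_a, theta_b):
    """Conditional 2-Wasserstein distance between two layered particle systems.

    theta_a, theta_b: arrays of shape (L, H), holding H scalar head parameters at
    each of L depth levels. Particles are matched only within the same layer.
    """
    a = np.sort(theta_a, axis=1)                   # 1D optimal coupling = sorted matching
    b = np.sort(theta_b, axis=1)
    w2_sq_per_layer = np.mean((a - b) ** 2, axis=1)  # W_2^2(rho_s, rho'_s) for each layer
    return np.sqrt(np.mean(w2_sq_per_layer))         # (average over depth)^(1/2)
```

Swapping layers shows why the conditioning matters. Take a system whose layers hold parameters (0, 0) and (1, 1), and a second system with those layers in reverse order. Both have the same pooled distribution of heads, so their unconditional W_2 distance is 0. Their conditional distance is 1, because each layer must transport its heads by a distance of 1.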
Reference graph
Works this paper leans on
- [1] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. In Advances in Neural Information Processing Systems, volume 36, 2023.
- [2] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? Investigations with linear models. In International Conference on Learning Representations, 2023.
- [3] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pages 242–252. PMLR, 2019.
- [4] Luigi Ambrosio. Transport equation and Cauchy problem for non-smooth vector fields. In Calculus of Variations and Nonlinear Partial Differential Equations, pages 1–41. Springer, 2008.
- [5] Luigi Ambrosio and Nicola Gigli. A user's guide to optimal transport. In Benedetto Piccoli and Michel Rascle, editors, Modelling and Optimisation of Flows on Networks, volume 2062 of Lecture Notes in Mathematics, pages 1–155. Springer, Berlin, Heidelberg, 2012.
- [6] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré. Gradient Flows: In Metric Spaces and in the Space of Probability Measures. Lectures in Mathematics ETH Zürich. Birkhäuser, Basel, second edition, 2008.
- [7] Vladimir Arnold. Sur la géométrie différentielle des groupes de Lie de dimension infinie et ses applications à l'hydrodynamique des fluides parfaits. Annales de l'Institut Fourier, 16(1):319–361, 1966.
- [8] Hedy Attouch, Giuseppe Buttazzo, and Gérard Michaille. Variational Analysis in Sobolev and BV Spaces: Applications to PDEs and Optimization. SIAM, 2014.
- [9] Bernd Aulbach and Thomas Wanner. Integral manifolds for Carathéodory type differential equations in Banach spaces. In Bernd Aulbach and Fritz Colonius, editors, Six Lectures on Dynamical Systems, pages 45–119. World Scientific, Singapore, 1996.
- [10] Raphaël Barboni, Gabriel Peyré, and François-Xavier Vialard. Understanding the training of infinitely deep and wide ResNets with conditional optimal transport. Communications on Pure and Applied Mathematics, 78(11):2149–2205, 2025.
- [11] Valérie Castin, Pierre Ablin, and Gabriel Peyré. How smooth is attention? In International Conference on Machine Learning, pages 5817–5840. PMLR, 2024.
- [12] Sourav Chatterjee. Convergence of gradient descent for deep neural networks. arXiv preprint arXiv:2203.16462, 2022.
- [13] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. Neural ordinary differential equations. Advances in Neural Information Processing Systems, 31, 2018.
- [14] Lénaïc Chizat and Francis Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in Neural Information Processing Systems, 31:3036–3046, 2018.
- [15] Ennio De Giorgi. New problems on minimizing movements. In Claudio Baiocchi and Jacques-Louis Lions, editors, Boundary Value Problems for Partial Differential Equations and Applications, volume 29 of RMA Research Notes in Applied Mathematics, pages 81–98. Masson, Paris, 1993.
- [16] Klaus Deimling. Ordinary Differential Equations in Banach Spaces, volume 596 of Lecture Notes in Mathematics. Springer, Berlin, Heidelberg, 1977.
- [17] Lorenzo Dello Schiavo, Jan Maas, and Francesco Pedrotti. Local conditions for global convergence of gradient flows and proximal point sequences in metric spaces. Transactions of the American Mathematical Society, 377(06):3779–3804, 2024.
- [18] Joe Diestel and Jerry J. Uhl, Jr. Vector Measures, volume 15 of Mathematical Surveys. American Mathematical Society, Providence, RI, 1977.
- [19] Zhiyan Ding, Shi Chen, Qin Li, and Stephen Wright. On the global convergence of gradient descent for multi-layer ResNets in the mean-field regime. arXiv preprint arXiv:2110.02926, 2021.
- [20] Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning, pages 1675–
- [21] Takashi Furuya, Maarten V. de Hoop, and Gabriel Peyré. Transformers are universal in-context learners. In International Conference on Learning Representations, 2025.
- [22] Cheng Gao, Yuan Cao, Zihao Li, Yihan He, Mengdi Wang, Han Liu, Jason Klusowski, and Jianqing Fan. Global convergence in training large-scale Transformers. Advances in Neural Information Processing Systems, 37:29213–29284, 2024.
- [23] Shivam Garg, Dimitris Tsipras, Percy S. Liang, and Gregory Valiant. What can Transformers learn in-context? A case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
- [24] Borjan Geshkovski, Cyril Letrouit, Yury Polyanskiy, and Philippe Rigollet. The emergence of clusters in self-attention dynamics. Advances in Neural Information Processing Systems, 36:57026–57037, 2023.
- [25] Daniel Hauer and José Mazón. Kurdyka–Łojasiewicz–Simon inequality for gradient flows in metric spaces. Transactions of the American Mathematical Society, 372(7):4917–4976, 2019.
- [26] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
- [27] Noboru Isobe. A convergence result of a continuous model of deep learning via Łojasiewicz–Simon inequality. arXiv preprint arXiv:2311.15365, 2023.
- [28] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems, 31, 2018.
- [29] Richard Jordan, David Kinderlehrer, and Felix Otto. The variational formulation of the Fokker–Planck equation. SIAM Journal on Mathematical Analysis, 29(1):1–17, 1998.
- [30] Chaoyue Liu, Libin Zhu, and Mikhail Belkin. On the linearity of large non-linear models: When and why the tangent kernel is constant. Advances in Neural Information Processing Systems, 33, 2020.
- [31] Yiping Lu, Chao Ma, Yulong Lu, Jianfeng Lu, and Lexing Ying. A mean field analysis of deep ResNet and beyond: Towards provably optimization via overparameterization from depth. In International Conference on Machine Learning, pages 6426–6436. PMLR, 2020.
- [32] Arvind Mahankali, Tatsunori B. Hashimoto, and Tengyu Ma. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. In International Conference on Learning Representations, 2024.
- [33] Pierre Marion, Adeline Fermanian, Gérard Biau, and Jean-Philippe Vert. Scaling ResNets in the large-depth regime. Journal of Machine Learning Research, 26(56):1–48, 2025.
- [34] Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
- [35] Charles A. Micchelli, Yuesheng Xu, and Haizhang Zhang. Universal kernels. Journal of Machine Learning Research, 7(95):2651–2667, 2006.
- [36] Jan Peszek and David Poyato. Heterogeneous gradient flows in the topology of fibered optimal transport. Calculus of Variations and Partial Differential Equations, 62(9):258, 2023.
- [37] Benedetto Piccoli, Francesco Rossi, and Emmanuel Trélat. Control to flocking of the kinetic Cucker–Smale model. SIAM Journal on Mathematical Analysis, 47(6):4685–4719, 2015.
- [38] Zhen Qin, Jinxin Zhou, Jiachen Jiang, and Zhihui Zhu. On the convergence of gradient descent on learning Transformers with residual connections. IEEE Signal Processing Letters, pages 1–5, 2026.
- [39] Grant Rotskoff and Eric Vanden-Eijnden. Parameters as interacting particles: long time convergence and asymptotic error scaling of neural networks. Advances in Neural Information Processing Systems, 31, 2018.
- [40] Michael E. Sander, Pierre Ablin, Mathieu Blondel, and Gabriel Peyré. Sinkformers: Transformers with doubly stochastic attention. In International Conference on Artificial Intelligence and Statistics, pages 3515–3530. PMLR, 2022.
- [41] Filippo Santambrogio. Optimal Transport for Applied Mathematicians: Calculus of Variations, PDEs, and Modeling, volume 87 of Progress in Nonlinear Differential Equations and Their Applications. Birkhäuser, Cham, 2015.
- [42] Filippo Santambrogio. Euclidean, metric, and Wasserstein gradient flows: An overview. Bulletin of Mathematical Sciences, 7:87–154, 2017.
- [43] Bingqing Song, Boran Han, Shuai Zhang, Jie Ding, and Mingyi Hong. Unraveling the gradient descent dynamics of Transformers. Advances in Neural Information Processing Systems, 37:92317–92351, 2024.
- [44] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
- [45] Cédric Villani. Optimal Transport: Old and New, volume 338. Springer, 2009.
- [46] Yongtao Wu, Fanghui Liu, Grigorios Chrysos, and Volkan Cevher. On the convergence of encoder-only shallow Transformers. Advances in Neural Information Processing Systems, 36:52197–52237, 2023.
- [47] Laurent Younes. Shapes and Diffeomorphisms, volume 171 of Applied Mathematical Sciences. Springer, Berlin, Heidelberg, 2010.
- [48] Hongyi Zhang, Yann N. Dauphin, and Tengyu Ma. Fixup initialization: Residual learning without normalization. In International Conference on Learning Representations, 2019.