Closed-Form Last Layer Optimization
Pith reviewed 2026-05-18 09:54 UTC · model grok-4.3
The pith
Treating the last layer weights as the closed-form optimum for the current backbone reduces neural net training to backbone optimization alone and converges to the global solution in the NTK regime.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
Expressing the last-layer weights as the ordinary-least-squares solution given the backbone features turns the training objective into a function of the backbone parameters alone. The stochastic variant, which balances the current-batch loss against information accumulated from earlier batches, converges in the neural-tangent-kernel regime to the same optimum that full-batch gradient descent would reach.
What carries the argument
The closed-form last-layer solution, obtained by solving the linear least-squares problem for the output weights given the current backbone activations, which allows the loss to be written and differentiated only with respect to the backbone.
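For concreteness, the substitution can be written out as below. The notation is an assumption made for this sketch rather than the paper's own: \Phi_\theta \in \mathbb{R}^{n \times d} stacks the backbone features \phi(x_i; \theta), Y \in \mathbb{R}^{n \times k} stacks the targets, W \in \mathbb{R}^{d \times k} is the last layer, and \Phi_\theta^\top \Phi_\theta is assumed invertible (a small ridge term is a common way to guarantee this).

```latex
\begin{aligned}
L(\theta, W) &= \tfrac{1}{2}\,\lVert \Phi_\theta W - Y \rVert_F^2, \\
W^\star(\theta) &= \arg\min_W L(\theta, W)
                 = (\Phi_\theta^\top \Phi_\theta)^{-1} \Phi_\theta^\top Y, \\
L^\star(\theta) &= L\bigl(\theta, W^\star(\theta)\bigr).
\end{aligned}
```

Under these assumptions, training reduces to minimizing L^\star over \theta alone.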
If this is right
- The method is equivalent to alternating exact last-layer solves with gradient steps on the backbone.
- In the NTK regime the stochastic adaptation still converges to the globally optimal solution.
- One-step analysis shows a quantifiable reduction in loss compared with ordinary SGD.
- Empirical results on neural-operator and causal-inference regression tasks show lower final loss than SGD or Adam.
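The first bullet above can be checked numerically on a toy problem. The sketch below is not the paper's implementation: the one-layer tanh backbone, the data, the learning rate, and all function names are hypothetical choices made for illustration. It compares the gradient of the reduced loss (last layer re-solved at every evaluation) with the gradient obtained when the last layer is frozen at its closed-form optimum, then runs a few alternating steps.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))      # toy inputs
Y = rng.normal(size=(50, 1))      # toy regression targets
theta0 = rng.normal(size=(3, 4))  # parameters of a hypothetical one-layer backbone


def features(theta):
    """Backbone feature map phi(x; theta); tanh is an arbitrary choice."""
    return np.tanh(X @ theta)


def solve_last_layer(theta):
    """Exact least-squares last layer for the current backbone."""
    return np.linalg.lstsq(features(theta), Y, rcond=None)[0]


def loss(theta, W):
    """Squared loss with the last layer W held fixed."""
    return 0.5 * np.mean((features(theta) @ W - Y) ** 2)


def reduced_loss(theta):
    """Loss as a function of the backbone only, with W = W*(theta)."""
    return loss(theta, solve_last_layer(theta))


def fd_grad(f, theta, eps=1e-6):
    """Central finite-difference gradient (slow, but dependency-free)."""
    g = np.zeros_like(theta)
    for idx in np.ndindex(*theta.shape):
        step = np.zeros_like(theta)
        step[idx] = eps
        g[idx] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return g


# Gradient of the reduced loss vs. gradient with W frozen at W*(theta0).
W_star = solve_last_layer(theta0)
grad_reduced = fd_grad(reduced_loss, theta0)
grad_frozen = fd_grad(lambda t: loss(t, W_star), theta0)
max_gap = np.max(np.abs(grad_reduced - grad_frozen))

# Alternating scheme: exact last-layer solve, then one backbone gradient step.
theta, lr, losses = theta0.copy(), 0.05, []
for _ in range(20):
    W = solve_last_layer(theta)                                # closed-form update
    losses.append(loss(theta, W))
    theta = theta - lr * fd_grad(lambda t: loss(t, W), theta)  # backbone step
losses.append(reduced_loss(theta))
```

If the equivalence holds, `max_gap` should sit at the level of finite-difference noise, and `losses` should trend downward for a sufficiently small learning rate. In practice one would use automatic differentiation instead of finite differences.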
Where Pith is reading between the lines
- The same substitution trick could be attempted for any layer whose output weights have a closed-form optimum, not only the final layer.
- If an approximate closed-form update can be derived for cross-entropy loss, the approach might extend beyond regression.
- The reported gains may persist outside the strict NTK regime even if the convergence proof does not.
Load-bearing premise
The training loss is squared error, so that the optimal last-layer weights admit an exact closed-form expression.
What would settle it
Run the method and standard SGD for the same number of steps on a regression task. If the proposed method reaches neither a lower loss nor faster convergence, the claimed practical and theoretical gains are falsified.
Original abstract
Neural networks are typically optimized with variants of stochastic gradient descent. Under a squared loss, however, the optimal solution to the linear last layer weights is known in closed-form. We propose to leverage this during optimization, treating the last layer as a function of the backbone parameters, and optimizing solely for these parameters. We show this is equivalent to alternating between gradient descent steps on the backbone and closed-form updates on the last layer. We adapt the method for the setting of stochastic gradient descent, by trading off the loss on the current batch against the accumulated information from previous batches. We provide theoretical analyses showing convergence of the method to an optimal solution in the neural tangent kernel regime, as well as quantifying the gains compared to standard SGD in a one-step analysis. Finally, we demonstrate the effectiveness of our approach compared with SGD and Adam on a squared loss in several regression tasks, including neural operators and causal inference.
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The paper proposes Closed-Form Last Layer Optimization for neural networks trained under squared loss. By expressing the optimal last-layer weights in closed form as a function of the backbone parameters, the method optimizes only the backbone via gradient steps; this is shown to be equivalent to alternating between backbone gradient descent and exact least-squares solves for the last layer. The approach is extended to the stochastic setting by maintaining running accumulators for feature statistics (ΣΦᵀΦ and ΣΦᵀy) that trade off the current mini-batch against historical data. Theoretical results establish convergence to an optimal joint solution in the neural tangent kernel regime together with a one-step gain analysis relative to vanilla SGD. Experiments on regression tasks (neural operators, causal inference) compare the method favorably to SGD and Adam.
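The running accumulators described in the summary might look roughly like the sketch below. The exact weighting used in the paper is not given on this page, so the exponential-moving-average form, the `decay` and `ridge` parameters, and the class name `LastLayerAccumulator` are all assumptions made for illustration.

```python
import numpy as np


class LastLayerAccumulator:
    """Running feature statistics for a closed-form last layer (assumed EMA form)."""

    def __init__(self, feat_dim, out_dim, decay=0.9, ridge=1e-10):
        self.A = np.zeros((feat_dim, feat_dim))  # running estimate of Phi^T Phi
        self.b = np.zeros((feat_dim, out_dim))   # running estimate of Phi^T y
        self.decay = decay                       # weight on history (hypothetical knob)
        self.ridge = ridge                       # keeps the linear solve well-posed

    def update(self, Phi, Y):
        """Blend the current mini-batch statistics into the accumulators."""
        batch = Phi.shape[0]
        self.A = self.decay * self.A + (1 - self.decay) * (Phi.T @ Phi) / batch
        self.b = self.decay * self.b + (1 - self.decay) * (Phi.T @ Y) / batch

    def solve(self):
        """Last-layer weights from the accumulated normal equations."""
        eye = np.eye(self.A.shape[0])
        return np.linalg.solve(self.A + self.ridge * eye, self.b)


# Sanity check with a frozen backbone (no feature drift) and noise-free targets.
rng = np.random.default_rng(0)
W_true = rng.normal(size=(5, 2))
acc = LastLayerAccumulator(feat_dim=5, out_dim=2)
for _ in range(10):
    Phi = rng.normal(size=(32, 5))  # stand-in for backbone features of one batch
    acc.update(Phi, Phi @ W_true)
W_hat = acc.solve()
```

With the backbone frozen and noise-free targets, the accumulated system is consistent and `W_hat` recovers `W_true` up to the ridge term. Once the backbone moves between batches, the stored statistics describe outdated features, which is the situation the first major comment below examines.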
Significance. If the stochastic adaptation and NTK convergence hold, the method supplies a principled way to exploit the closed-form last-layer optimum during training, potentially reducing the number of effective parameters optimized by gradient descent. The one-step analysis and NTK lens provide concrete theoretical grounding, while the regression experiments on neural operators and causal inference demonstrate applicability to non-trivial tasks. These elements strengthen the contribution beyond a purely heuristic trick.
major comments (2)
- [Stochastic Adaptation] Stochastic adaptation (around the description of running accumulators ΣΦᵀΦ and ΣΦᵀy): after each backbone parameter update the feature map φ(·;θ) changes for all previously seen points, so the accumulated statistics no longer correspond to the current θ. The linear system solved at each step is therefore not the exact least-squares problem for the present backbone. The NTK analysis assumes infinitesimal feature drift; the manuscript should supply an explicit error bound or finite-step analysis showing that the mismatch remains controlled for the step sizes used in practice.
- [Theoretical Analysis] Convergence claim in the NTK regime: the deterministic alternating procedure reaches the joint optimum by construction, but the stochastic version relies on the accumulators remaining approximately consistent with the drifting features. Please state the precise conditions (step-size scaling, batch size, NTK width) under which the claimed convergence to the global optimum of the joint problem continues to hold.
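The drift concern raised in both major comments can be made tangible with a small experiment. The sketch below is an illustration under assumptions, not an analysis from the paper: it uses a hypothetical tanh backbone and a random update direction, and measures how far the Gram matrix computed at the old parameters is from the one implied by the updated parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))        # previously seen inputs
theta_old = rng.normal(size=(3, 4))  # backbone when the statistics were accumulated
direction = rng.normal(size=(3, 4))  # stand-in for a backbone update direction


def gram(theta):
    """Normalized Gram matrix Phi^T Phi / n for a hypothetical tanh backbone."""
    Phi = np.tanh(X @ theta)
    return Phi.T @ Phi / X.shape[0]


def relative_mismatch(step_size):
    """Relative Frobenius gap between stale and recomputed Gram matrices."""
    stale = gram(theta_old)                          # what the accumulator holds
    fresh = gram(theta_old + step_size * direction)  # what the current backbone implies
    return np.linalg.norm(fresh - stale) / np.linalg.norm(stale)


step_sizes = [1e-1, 1e-2, 1e-3]
mismatch = [relative_mismatch(s) for s in step_sizes]
for s, m in zip(step_sizes, mismatch):
    print(f"step size {s:g}: relative Gram mismatch {m:.3e}")
```

For small steps the mismatch shrinks roughly in proportion to the step size, which matches the intuition behind the infinitesimal-drift assumption. It says nothing about how the error compounds over many steps, which is what the requested bound would need to address.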
minor comments (2)
- [One-step Analysis] One-step analysis: report the numerical magnitude of the predicted gain (e.g., reduction in loss or effective learning-rate improvement) so readers can judge its practical relevance.
- [Experiments] Experiments: include standard-error bars or multiple random seeds for all reported metrics to allow assessment of variability.
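The reporting requested in the last comment amounts to a mean and a standard error over independent seeds. A minimal sketch follows; the function name is hypothetical and the input values are placeholders, not results from the paper.

```python
import math
import statistics


def mean_and_standard_error(values):
    """Return (mean, standard error of the mean) over independent seeds."""
    if len(values) < 2:
        raise ValueError("need at least two seeds to estimate a standard error")
    mean = statistics.fmean(values)
    sem = statistics.stdev(values) / math.sqrt(len(values))  # sample std (n - 1)
    return mean, sem


# Placeholder numbers for illustration only; they are not results from the paper.
placeholder_losses = [1.0, 2.0, 3.0]
mean, sem = mean_and_standard_error(placeholder_losses)
```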
Simulated Author's Rebuttal
We thank the referee for their constructive and insightful comments. We address each major comment below and propose targeted revisions to clarify the stochastic approximation and the precise conditions for the NTK convergence result.
-
Referee: [Stochastic Adaptation] Stochastic adaptation (around the description of running accumulators ΣΦᵀΦ and ΣΦᵀy): after each backbone parameter update the feature map φ(·;θ) changes for all previously seen points, so the accumulated statistics no longer correspond to the current θ. The linear system solved at each step is therefore not the exact least-squares problem for the present backbone. The NTK analysis assumes infinitesimal feature drift; the manuscript should supply an explicit error bound or finite-step analysis showing that the mismatch remains controlled for the step sizes used in practice.
Authors: We agree that the running accumulators yield an approximation rather than the exact least-squares solution once θ is updated, since all prior features drift. The method intentionally employs a convex combination of the current mini-batch statistics with historical accumulators (controlled by a decay factor) to enable online training without full re-computation. This is analogous to momentum or exponential moving average techniques. Under the NTK regime of the analysis, per-step feature changes are infinitesimal for small learning rates, keeping the mismatch controlled. We will revise the manuscript to add an explicit discussion of this approximation, a qualitative error argument based on the NTK drift bound, and new empirical figures quantifying the deviation from the exact closed-form solution for the step sizes and batch sizes used in the experiments. A fully rigorous finite-step error bound for arbitrary widths and step sizes lies beyond the present scope.
Revision: partial
-
Referee: [Theoretical Analysis] Convergence claim in the NTK regime: the deterministic alternating procedure reaches the joint optimum by construction, but the stochastic version relies on the accumulators remaining approximately consistent with the drifting features. Please state the precise conditions (step-size scaling, batch size, NTK width) under which the claimed convergence to the global optimum of the joint problem continues to hold.
Authors: We thank the referee for requesting this clarification. The convergence result is established in the neural tangent kernel regime. We will revise the theoretical section to state the conditions explicitly: (i) the infinite-width NTK limit, (ii) step sizes η scaled as o(1) (specifically satisfying the standard NTK stability condition η < 2/λ_max of the limiting kernel Gram matrix), and (iii) any fixed batch size B ≥ 1. Under these conditions the feature map remains sufficiently close to its initial value that the accumulators track the current features with vanishing error, recovering convergence to the joint optimum. The deterministic alternating procedure is recovered exactly in the limit of infinitesimal steps; the stochastic version inherits the same guarantee under the stated scaling.
Revision: yes
- A complete, non-asymptotic finite-step error bound on the accumulator mismatch that holds for finite-width networks and arbitrary step sizes would require substantial new technical development outside the NTK framework and is not provided in the current manuscript.
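The step-size condition η < 2/λ_max quoted in the authors' response is the usual stability threshold for gradient descent on a quadratic. The sketch below illustrates it under assumptions: a small positive-definite matrix stands in for the limiting kernel Gram matrix, and the linearized residual recursion r ← (I − ηK) r is used in place of actual network training.

```python
import numpy as np

rng = np.random.default_rng(0)
Phi = rng.normal(size=(40, 5))
K = Phi.T @ Phi / Phi.shape[0]       # positive-definite stand-in for the kernel Gram matrix
lam_max = np.linalg.eigvalsh(K)[-1]  # eigvalsh returns eigenvalues in ascending order
r0 = rng.normal(size=5)              # initial residual in the linearized model


def residual_norm_after(eta, steps=200):
    """Iterate r <- (I - eta * K) r, the linearized gradient-descent recursion."""
    r = r0.copy()
    for _ in range(steps):
        r = r - eta * (K @ r)
    return np.linalg.norm(r)


stable = residual_norm_after(1.9 / lam_max)    # just inside eta < 2 / lam_max
unstable = residual_norm_after(2.1 / lam_max)  # just outside the bound
print(f"initial {np.linalg.norm(r0):.3e}, stable {stable:.3e}, unstable {unstable:.3e}")
```

Each eigencomponent of the residual is scaled by |1 − ηλ_i| per step, so the residual shrinks when η is below the threshold and the top eigencomponent grows when η exceeds it.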
Circularity Check
Derivation chain is self-contained with no circular reductions
The paper establishes a mathematical equivalence between joint optimization of backbone parameters (with last layer expressed as a closed-form function of them) and alternating gradient steps on the backbone with exact least-squares solves on the last layer; this follows directly from the normal equations for squared loss without redefining any quantity in terms of the target result. The stochastic adaptation maintains running accumulators for the Gram matrix and cross-term while trading off current-batch loss, and the convergence claim is derived under the standard external neural tangent kernel regime (infinitesimal feature drift limit) rather than by fitting parameters to the method's own outputs or importing uniqueness from self-citation. No step renames a known empirical pattern, smuggles an ansatz via prior work by the same authors, or presents a fitted input as a prediction. The one-step gain analysis is likewise independent of the main convergence argument. The derivation therefore remains non-circular and externally grounded.
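The equivalence step referred to above can be written out schematically. The notation is assumed for this sketch: L(\theta, W) = \tfrac{1}{2}\lVert \Phi_\theta W - Y \rVert_F^2 is the squared loss, W^\star(\theta) its minimizer over W for fixed \theta, and L^\star(\theta) = L(\theta, W^\star(\theta)) the reduced loss. The chain-rule term is written loosely, without tracking tensor shapes.

```latex
\begin{aligned}
\nabla_\theta L^\star(\theta)
  &= \partial_\theta L(\theta, W)\big|_{W = W^\star(\theta)}
   + \Bigl(\tfrac{\partial W^\star}{\partial \theta}\Bigr)^{\!\top}
     \partial_W L(\theta, W)\big|_{W = W^\star(\theta)}, \\
\partial_W L(\theta, W)\big|_{W = W^\star(\theta)}
  &= \Phi_\theta^\top \bigl(\Phi_\theta W^\star(\theta) - Y\bigr) = 0
  \quad \text{(normal equations)}, \\
\Rightarrow\quad \nabla_\theta L^\star(\theta)
  &= \partial_\theta L(\theta, W)\big|_{W = W^\star(\theta)}.
\end{aligned}
```

So a gradient step on the reduced loss coincides with a backbone gradient step taken with the last layer held at its closed-form optimum. This is a first-order identity only and does not carry over to second derivatives.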
Axiom & Free-Parameter Ledger
axioms (2)
- domain assumption: Squared loss admits a closed-form optimum for the final linear layer
- domain assumption: Neural tangent kernel regime approximates the training dynamics of the backbone
Lean theorems connected to this paper
-
IndisputableMonolith/Cost/FunctionalEquation.lean · washburn_uniqueness_aczel · unclear
Relation between the paper passage and the cited Recognition theorem: unclear.
We propose to leverage this during optimization, treating the last layer as a function of the backbone parameters, and optimizing solely for these parameters. We show this is equivalent to alternating between gradient descent steps on the backbone and closed-form updates on the last layer.
-
IndisputableMonolith/Foundation/BranchSelection.lean · branch_selection · unclear
Relation between the paper passage and the cited Recognition theorem: unclear.
We adapt the method for the setting of stochastic gradient descent, by trading off the loss on the current batch against the accumulated information from previous batches.
What do these tags mean?
- matches: The paper's claim is directly supported by a theorem in the formal canon.
- supports: The theorem supports part of the paper's argument, but the paper may add assumptions or extra steps.
- extends: The paper goes beyond the formal theorem; the theorem is a base layer rather than the whole result.
- uses: The paper appears to rely on the theorem as machinery.
- contradicts: The paper's claim conflicts with a theorem or certificate in the canon.
- unclear: Pith found a possible connection, but the passage is too broad, indirect, or ambiguous to say the theorem truly supports the claim.
Forward citations
Cited by 3 Pith papers
-
Doubly Robust Proxy Causal Learning with Neural Mean Embeddings
A neural doubly robust proxy causal learning framework using mean embeddings for treatment bridges provides consistent estimators for causal dose-response functions under unobserved confounding for continuous and stru...
-
Rethinking Neural Network Learning Rates: A Stackelberg Perspective
Non-uniform learning rates correspond to a Stackelberg reformulation of the training objective whose two-time-scale alternating gradient descent yields finite-time convergence and can accelerate training through stron...
-
Context-Aware Model Predictive Control for Microgrid Energy Management via LLMs
InstructMPC uses an LLM plus tunable last layer to map operational context to disturbance trajectories for MPC, proving an O(sqrt(T log T)) regret bound for linear systems and showing lower grid costs on the OpenCEM m...