
arxiv: 2510.04606 · v2 · submitted 2025-10-06 · 💻 cs.LG · stat.ML

Closed-Form Last Layer Optimization

Pith reviewed 2026-05-18 09:54 UTC · model grok-4.3

classification 💻 cs.LG stat.ML
keywords last-layer optimization · closed-form solution · squared loss · neural tangent kernel · stochastic gradient descent · backbone parameters · regression tasks

The pith

Treating the last layer weights as the closed-form optimum for the current backbone reduces neural net training to backbone optimization alone and converges to the global solution in the NTK regime.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper shows that under squared loss the optimal last-layer weights can be written in closed form as a function of the backbone activations. Substituting this expression into the loss turns the entire optimization into a problem over backbone parameters only. The resulting procedure is mathematically equivalent to running gradient descent on the backbone while resetting the last layer to its exact optimum after each step. For the stochastic case the authors replace the exact solve with a running estimate that trades off the current batch against accumulated statistics from prior batches. Theory establishes that the stochastic version still reaches the optimal NTK solution, and experiments on regression tasks report faster loss reduction than both SGD and Adam.

Core claim

By expressing the last-layer weights as the ordinary-least-squares solution given the backbone features, the training objective becomes a function solely of the backbone parameters; the stochastic variant that balances current-batch loss with accumulated information converges to the same optimum that full-batch gradient descent would reach in the neural-tangent-kernel regime.

What carries the argument

The closed-form last-layer solution, obtained by solving the linear least-squares problem for the output weights given the current backbone activations, which allows the loss to be written and differentiated only with respect to the backbone.
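
In symbols, the substitution can be sketched as follows. The notation is an assumption, since this page does not fix one: Φ_θ is the n × d matrix of backbone features (one row per training point), Y is the n × k target matrix, and β ≥ 0 is a small ridge term added here so the inverse always exists (β = 0 recovers ordinary least squares when Φ_θᵀΦ_θ is invertible).

```latex
W^\star(\theta) = \left(\Phi_\theta^\top \Phi_\theta + \beta I\right)^{-1} \Phi_\theta^\top Y,
\qquad
L^\star(\theta) = \tfrac{1}{2}\,\lVert \Phi_\theta W^\star(\theta) - Y \rVert_F^2
                + \tfrac{\beta}{2}\,\lVert W^\star(\theta) \rVert_F^2 .

% First-order optimality of W^\star removes the indirect term in the chain rule:
\nabla_\theta L^\star(\theta)
  = \partial_\theta L(\theta, W)\big|_{W = W^\star(\theta)}
  + \Big(\tfrac{\partial W^\star}{\partial \theta}\Big)^{\!\top}
    \underbrace{\partial_W L\big(\theta, W^\star(\theta)\big)}_{=\,0}
  = \partial_\theta L(\theta, W)\big|_{W = W^\star(\theta)} .
```

The second line is why differentiating the reduced loss only requires backbone gradients taken with the last layer held at its optimum.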

If this is right

  • The method is equivalent to alternating exact last-layer solves with gradient steps on the backbone.
  • In the NTK regime the stochastic adaptation still converges to the globally optimal solution.
  • One-step analysis shows a quantifiable reduction in loss compared with ordinary SGD.
  • Empirical results on neural-operator and causal-inference regression tasks show lower final loss than SGD or Adam.
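
The first bullet can be checked numerically on a toy problem. The sketch below is not the authors' code: it assumes a hypothetical one-layer backbone `features(theta, X) = tanh(X @ theta)`, a small ridge term `beta` for numerical stability, and finite-difference gradients. It compares the gradient of the reduced loss (last layer re-solved at every perturbed θ) with the gradient of the joint loss when the last layer is frozen at its optimum.

```python
import numpy as np

def features(theta, X):
    # Hypothetical backbone: a single tanh layer (assumption, not from the paper).
    return np.tanh(X @ theta)

def closed_form_last_layer(Phi, Y, beta=1e-3):
    # Ridge-regularised least-squares solution for the last layer.
    d = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + beta * np.eye(d), Phi.T @ Y)

def loss(theta, W, X, Y, beta=1e-3):
    # Joint squared loss in (theta, W), including the ridge penalty on W.
    R = features(theta, X) @ W - Y
    return 0.5 * np.sum(R ** 2) + 0.5 * beta * np.sum(W ** 2)

def reduced_loss(theta, X, Y, beta=1e-3):
    # Loss as a function of theta only: W is re-solved for each theta.
    W = closed_form_last_layer(features(theta, X), Y, beta)
    return loss(theta, W, X, Y, beta)

def finite_diff_grad(f, theta, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    g = np.zeros_like(theta)
    for idx in np.ndindex(theta.shape):
        e = np.zeros_like(theta)
        e[idx] = eps
        g[idx] = (f(theta + e) - f(theta - e)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
Y = rng.normal(size=(50, 2))
theta = rng.normal(size=(3, 4))

W_star = closed_form_last_layer(features(theta, X), Y)
g_reduced = finite_diff_grad(lambda t: reduced_loss(t, X, Y), theta)
g_frozen = finite_diff_grad(lambda t: loss(t, W_star, X, Y), theta)
max_gap = np.max(np.abs(g_reduced - g_frozen))
print("max gradient gap:", max_gap)
```

The two gradients agree up to finite-difference error because the last layer sits at a stationary point of the joint loss, so its dependence on θ contributes nothing at first order.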

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • The same substitution trick could be attempted for any layer whose output weights have a closed-form optimum, not only the final layer.
  • If an approximate closed-form update can be derived for cross-entropy loss, the approach might extend beyond regression.
  • The reported gains may persist outside the strict NTK regime even if the convergence proof does not.

Load-bearing premise

The training loss is squared error, so that the optimal last-layer weights admit an exact closed-form expression.

What would settle it

If the method and standard SGD are run for the same number of steps on a regression task and the proposed method neither attains lower loss nor converges faster, the claimed practical gains would be contradicted and the one-step analysis would be called into question.

Figures

Figures reproduced from arXiv: 2510.04606 by Alexandre Galashov, Arthur Gretton, Liyuan Xu, Nathaël Da Costa, Philipp Hennig.

Figure 1: The squared loss landscape of a two-parameter neural…
Figure 2.
Figure 3: DFIV results. X-axis is the number of iterations, Y-axis is a test set MSE. Each column corresponds to a different batch size. Different colors indicate different methods. Solid lines use the last layer re-estimated on the entire training set, while dashed lines use current last layer estimates. We use a rolling average with window size 5 to smooth the curves.
Figure 4: CIFAR-10 results. X-axis is the number of iterations, Y-axis is a test set accuracy. Each column corresponds to a different batch size. Different colors indicate different methods.
Figure 5: CIFAR-100 results. X-axis is the number of iterations, Y-axis is a test set accuracy. Each column corresponds to a different batch size. Different colors indicate different methods.
Figure 6: Dependence on hyperparameters on CIFAR-100. X-axis is the number of iterations, Y-axis is a test set accuracy. Left, ablation over λ for “ℓ2 c.f. proximal (λ)”. Center, ablation over β for “ℓ2 c.f. ridge (β)”. Right, the best learning rate per method.
Figure 7: Comparison of Algorithm 2 and Algorithm 1. X-axis is the number of iterations, Y-axis is a test set accuracy. A column indicates a batch size while a color represents an algorithm.
Figure 8: ImageNet results. X-axis is the number of iterations, Y-axis is a test set accuracy. Each column corresponds to a different batch size. Different colors indicate different methods.
Original abstract

Neural networks are typically optimized with variants of stochastic gradient descent. Under a squared loss, however, the optimal solution to the linear last layer weights is known in closed-form. We propose to leverage this during optimization, treating the last layer as a function of the backbone parameters, and optimizing solely for these parameters. We show this is equivalent to alternating between gradient descent steps on the backbone and closed-form updates on the last layer. We adapt the method for the setting of stochastic gradient descent, by trading off the loss on the current batch against the accumulated information from previous batches. We provide theoretical analyses showing convergence of the method to an optimal solution in the neural tangent kernel regime, as well as quantifying the gains compared to standard SGD in a one-step analysis. Finally, we demonstrate the effectiveness of our approach compared with SGD and Adam on a squared loss in several regression tasks, including neural operators and causal inference.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

2 major / 2 minor

Summary. The paper proposes Closed-Form Last Layer Optimization for neural networks trained under squared loss. By expressing the optimal last-layer weights in closed form as a function of the backbone parameters, the method optimizes only the backbone via gradient steps; this is shown to be equivalent to alternating between backbone gradient descent and exact least-squares solves for the last layer. The approach is extended to the stochastic setting by maintaining running accumulators for feature statistics (ΣΦᵀΦ and ΣΦᵀy) that trade off the current mini-batch against historical data. Theoretical results establish convergence to an optimal joint solution in the neural tangent kernel regime together with a one-step gain analysis relative to vanilla SGD. Experiments on regression tasks (neural operators, causal inference) compare the method favorably to SGD and Adam.
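
The running accumulators mentioned in the summary can be sketched as below. This is an illustration under stated assumptions, not the paper's update rule: the class name `RunningLastLayer`, the exponential `decay` factor, and the ridge term `beta` are all hypothetical choices, and the paper's own proximal and ridge variants may weight the current batch differently.

```python
import numpy as np

class RunningLastLayer:
    """Hypothetical accumulator of last-layer sufficient statistics."""

    def __init__(self, d, k, decay=0.9, beta=1e-3):
        self.A = np.zeros((d, d))  # running (decayed) sum of Phi^T Phi
        self.b = np.zeros((d, k))  # running (decayed) sum of Phi^T Y
        self.decay = decay         # weight given to statistics from earlier batches
        self.beta = beta           # ridge term keeping the solve well-posed

    def update(self, Phi, Y):
        # Down-weight history, add the current batch, then re-solve for W.
        self.A = self.decay * self.A + Phi.T @ Phi
        self.b = self.decay * self.b + Phi.T @ Y
        d = self.A.shape[0]
        return np.linalg.solve(self.A + self.beta * np.eye(d), self.b)

# Sanity check with FIXED features: with decay=1.0 the streamed solution
# matches the full-batch ridge solution after one pass over the data.
rng = np.random.default_rng(0)
Phi = rng.normal(size=(64, 4))
Y = rng.normal(size=(64, 2))

acc = RunningLastLayer(d=4, k=2, decay=1.0, beta=1e-3)
for start in range(0, 64, 16):
    W_stream = acc.update(Phi[start:start + 16], Y[start:start + 16])

W_full = np.linalg.solve(Phi.T @ Phi + 1e-3 * np.eye(4), Phi.T @ Y)
print("streamed == full batch:", np.allclose(W_stream, W_full))
```

The check only holds because the features are held fixed; once the backbone moves between batches, the accumulated statistics describe older features, which is the situation the `decay` factor is meant to soften.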

Significance. If the stochastic adaptation and NTK convergence hold, the method supplies a principled way to exploit the closed-form last-layer optimum during training, potentially reducing the number of effective parameters optimized by gradient descent. The one-step analysis and NTK lens provide concrete theoretical grounding, while the regression experiments on neural operators and causal inference demonstrate applicability to non-trivial tasks. These elements strengthen the contribution beyond a purely heuristic trick.

major comments (2)
  1. [Stochastic Adaptation] Stochastic adaptation (around the description of running accumulators ΣΦᵀΦ and ΣΦᵀy): after each backbone parameter update the feature map φ(·;θ) changes for all previously seen points, so the accumulated statistics no longer correspond to the current θ. The linear system solved at each step is therefore not the exact least-squares problem for the present backbone. The NTK analysis assumes infinitesimal feature drift; the manuscript should supply an explicit error bound or finite-step analysis showing that the mismatch remains controlled for the step sizes used in practice.
  2. [Theoretical Analysis] Convergence claim in the NTK regime: the deterministic alternating procedure reaches the joint optimum by construction, but the stochastic version relies on the accumulators remaining approximately consistent with the drifting features. Please state the precise conditions (step-size scaling, batch size, NTK width) under which the claimed convergence to the global optimum of the joint problem continues to hold.
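
The mismatch raised in the first major comment can be made concrete with a toy experiment. The sketch below rests on several assumptions: a hypothetical tanh backbone, a random update direction standing in for a gradient step, and the worst case in which all statistics are stale (in the method only past batches would be). It measures how far the last layer solved from pre-update features is from the one solved from post-update features.

```python
import numpy as np

def features(theta, X):
    # Hypothetical backbone: a single tanh layer.
    return np.tanh(X @ theta)

def ridge_solve(Phi, Y, beta=1e-3):
    d = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + beta * np.eye(d), Phi.T @ Y)

def stale_vs_fresh_gap(step, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(100, 3))
    Y = rng.normal(size=(100, 1))
    theta = rng.normal(size=(3, 4))
    direction = rng.normal(size=(3, 4))  # stand-in for a backbone update direction
    theta_new = theta + step * direction

    W_stale = ridge_solve(features(theta, X), Y)      # statistics from old features
    W_fresh = ridge_solve(features(theta_new, X), Y)  # statistics recomputed at theta_new
    return np.linalg.norm(W_stale - W_fresh)

gaps = {step: stale_vs_fresh_gap(step) for step in (1e-3, 1e-1)}
for step, gap in gaps.items():
    print(f"step={step:g}  ||W_stale - W_fresh|| = {gap:.3e}")
```

On this toy problem the gap shrinks with the step size, which matches the intuition behind the infinitesimal-drift argument but says nothing about the finite-step bound the referee asks for.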
minor comments (2)
  1. [One-step Analysis] One-step analysis: report the numerical magnitude of the predicted gain (e.g., reduction in loss or effective learning-rate improvement) so readers can judge its practical relevance.
  2. [Experiments] Experiments: include standard-error bars or multiple random seeds for all reported metrics to allow assessment of variability.

Simulated Author's Rebuttal

2 responses · 1 unresolved

We thank the referee for their constructive and insightful comments. We address each major comment below and propose targeted revisions to clarify the stochastic approximation and the precise conditions for the NTK convergence result.

Point-by-point responses
  1. Referee: [Stochastic Adaptation] Stochastic adaptation (around the description of running accumulators ΣΦᵀΦ and ΣΦᵀy): after each backbone parameter update the feature map φ(·;θ) changes for all previously seen points, so the accumulated statistics no longer correspond to the current θ. The linear system solved at each step is therefore not the exact least-squares problem for the present backbone. The NTK analysis assumes infinitesimal feature drift; the manuscript should supply an explicit error bound or finite-step analysis showing that the mismatch remains controlled for the step sizes used in practice.

    Authors: We agree that the running accumulators yield an approximation rather than the exact least-squares solution once θ is updated, since all prior features drift. The method intentionally employs a convex combination of the current mini-batch statistics with historical accumulators (controlled by a decay factor) to enable online training without full re-computation. This is analogous to momentum or exponential moving average techniques. Under the NTK regime of the analysis, per-step feature changes are infinitesimal for small learning rates, keeping the mismatch controlled. We will revise the manuscript to add an explicit discussion of this approximation, a qualitative error argument based on the NTK drift bound, and new empirical figures quantifying the deviation from the exact closed-form solution for the step sizes and batch sizes used in the experiments. A fully rigorous finite-step error bound for arbitrary widths and step sizes lies beyond the present scope. revision: partial

  2. Referee: [Theoretical Analysis] Convergence claim in the NTK regime: the deterministic alternating procedure reaches the joint optimum by construction, but the stochastic version relies on the accumulators remaining approximately consistent with the drifting features. Please state the precise conditions (step-size scaling, batch size, NTK width) under which the claimed convergence to the global optimum of the joint problem continues to hold.

    Authors: We thank the referee for requesting this clarification. The convergence result is established in the neural tangent kernel regime. We will revise the theoretical section to state the conditions explicitly: (i) the infinite-width NTK limit, (ii) step sizes η scaled as o(1) (specifically satisfying the standard NTK stability condition η < 2/λ_max of the limiting kernel Gram matrix), and (iii) any fixed batch size B ≥ 1. Under these conditions the feature map remains sufficiently close to its initial value that the accumulators track the current features with vanishing error, recovering convergence to the joint optimum. The deterministic alternating procedure is recovered exactly in the limit of infinitesimal steps; the stochastic version inherits the same guarantee under the stated scaling. revision: yes
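
The step-size condition η < 2/λ_max quoted in the second response is the usual stability threshold for gradient descent on a quadratic. A minimal sketch, assuming the linearized (kernel) dynamics in which the training residual evolves as r ← (I − ηK) r, with a randomly generated Gram matrix K standing in for the limiting kernel; whether the paper's theorem uses exactly this condition is not confirmed by this page.

```python
import numpy as np

def residual_norm_after(K, r0, eta, steps=200):
    # Linearized gradient descent on squared loss: r <- (I - eta * K) r.
    r = r0.copy()
    I = np.eye(K.shape[0])
    for _ in range(steps):
        r = (I - eta * K) @ r
    return np.linalg.norm(r)

rng = np.random.default_rng(0)
J = rng.normal(size=(5, 20))         # hypothetical Jacobian: 5 points, 20 parameters
K = J @ J.T                          # Gram matrix, positive definite almost surely
lam_max = np.linalg.eigvalsh(K)[-1]  # eigvalsh returns eigenvalues in ascending order
r0 = rng.normal(size=5)

stable = residual_norm_after(K, r0, 1.9 / lam_max)    # just below the threshold
unstable = residual_norm_after(K, r0, 2.1 / lam_max)  # just above the threshold
print("initial:", np.linalg.norm(r0), "stable:", stable, "unstable:", unstable)
```

Below the threshold every eigen-component of the residual contracts; above it the component along the top eigenvector grows geometrically.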

standing simulated objections not resolved
  • A complete, non-asymptotic finite-step error bound on the accumulator mismatch that holds for finite-width networks and arbitrary step sizes would require substantial new technical development outside the NTK framework and is not provided in the current manuscript.

Circularity Check

0 steps flagged

Derivation chain is self-contained with no circular reductions

Full rationale

The paper establishes a mathematical equivalence between joint optimization of backbone parameters (with last layer expressed as a closed-form function of them) and alternating gradient steps on the backbone with exact least-squares solves on the last layer; this follows directly from the normal equations for squared loss without redefining any quantity in terms of the target result. The stochastic adaptation maintains running accumulators for the Gram matrix and cross-term while trading off current-batch loss, and the convergence claim is derived under the standard external neural tangent kernel regime (infinitesimal feature drift limit) rather than by fitting parameters to the method's own outputs or importing uniqueness from self-citation. No step renames a known empirical pattern, smuggles an ansatz via prior work by the same authors, or presents a fitted input as a prediction. The one-step gain analysis is likewise independent of the main convergence argument. The derivation therefore remains non-circular and externally grounded.

Axiom & Free-Parameter Ledger

0 free parameters · 2 axioms · 0 invented entities

Abstract-only review limits visibility into explicit free parameters or invented entities. The method implicitly relies on the squared-loss assumption and on the validity of the NTK linearization for the backbone dynamics.

axioms (2)
  • domain assumption Squared loss admits a closed-form optimum for the final linear layer
    Stated directly in the abstract as the starting point for the method.
  • domain assumption Neural tangent kernel regime approximates the training dynamics of the backbone
    Invoked for the convergence analysis.



Lean theorems connected to this paper

Citations machine-checked in the Pith Canon. Every link opens the source theorem in the public Lean library.

  • IndisputableMonolith/Cost/FunctionalEquation.lean washburn_uniqueness_aczel unclear

    Relation between the paper passage and the cited Recognition theorem.

    We propose to leverage this during optimization, treating the last layer as a function of the backbone parameters, and optimizing solely for these parameters. We show this is equivalent to alternating between gradient descent steps on the backbone and closed-form updates on the last layer.

  • IndisputableMonolith/Foundation/BranchSelection.lean branch_selection unclear

    Relation between the paper passage and the cited Recognition theorem.

    We adapt the method for the setting of stochastic gradient descent, by trading off the loss on the current batch against the accumulated information from previous batches.

What do these tags mean?

  • matches: The paper's claim is directly supported by a theorem in the formal canon.
  • supports: The theorem supports part of the paper's argument, but the paper may add assumptions or extra steps.
  • extends: The paper goes beyond the formal theorem; the theorem is a base layer rather than the whole result.
  • uses: The paper appears to rely on the theorem as machinery.
  • contradicts: The paper's claim conflicts with a theorem or certificate in the canon.
  • unclear: Pith found a possible connection, but the passage is too broad, indirect, or ambiguous to say the theorem truly supports the claim.

Forward citations

Cited by 3 Pith papers

Reviewed papers in the Pith corpus that reference this work. Sorted by Pith novelty score.

  1. Doubly Robust Proxy Causal Learning with Neural Mean Embeddings

    cs.LG 2026-05 unverdicted novelty 6.0

    A neural doubly robust proxy causal learning framework using mean embeddings for treatment bridges provides consistent estimators for causal dose-response functions under unobserved confounding for continuous and stru...

  2. Rethinking Neural Network Learning Rates: A Stackelberg Perspective

    cs.LG 2026-05 unverdicted novelty 5.0

    Non-uniform learning rates correspond to a Stackelberg reformulation of the training objective whose two-time-scale alternating gradient descent yields finite-time convergence and can accelerate training through stron...

  3. Context-Aware Model Predictive Control for Microgrid Energy Management via LLMs

    eess.SY 2025-12 unverdicted novelty 5.0

    InstructMPC uses an LLM plus tunable last layer to map operational context to disturbance trajectories for MPC, proving an O(sqrt(T log T)) regret bound for linear systems and showing lower grid costs on the OpenCEM m...
