pith.

arxiv: 2604.25409 · v1 · submitted 2026-04-28 · 💻 cs.CL

Scaling Probabilistic Transformer via Efficient Cross-Scale Hyperparameter Transfer

Pith reviewed 2026-05-07 16:26 UTC · model grok-4.3

classification 💻 cs.CL
keywords Probabilistic Transformer · Maximal Update Parametrization · hyperparameter transfer · model scaling · masked language modeling · transformer alternatives · white-box probabilistic model

The pith

Applying Maximal Update Parametrization lets Probabilistic Transformers scale up to 0.4B parameters by transferring hyperparameters tuned on small models, and the scaled models consistently outperform standard Transformers of equal size on masked language modeling.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper shows how Probabilistic Transformers, which are white-box probabilistic models for word representations, can overcome their hyperparameter sensitivity through Maximal Update Parametrization. By rescaling parameters according to muP rules, hyperparameters tuned on small models transfer directly to versions with up to 0.4 billion parameters. Experiments confirm these larger PT models beat standard Transformers on masked language modeling tasks when both use the same parameter budget. Readers care because this removes a practical barrier that had kept probabilistic alternatives from competing at scale. The work demonstrates that principled parameter adjustments can make non-standard architectures trainable at larger sizes without repeated retuning.

Core claim

Rescaling PT's parameters according to Maximal Update Parametrization (muP) lets hyperparameters optimized on small models transfer to larger models of up to 0.4B parameters without additional tuning. Experiments show that PT consistently outperforms the standard Transformer under the same parameter budget on Masked Language Modeling (MLM) tasks.

What carries the argument

Maximal Update Parametrization (muP) applied to rescale the parameters of the Probabilistic Transformer, which stabilizes training dynamics and enables direct hyperparameter transfer across model sizes.
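As a rough illustration of what rescaling PT's parameters means in practice, the sketch below encodes the three parameter groups listed in the appendix excerpts further down (entries 17 to 20): input, hidden, and output, each with its own initialization scale and learning-rate multiplier, with the hidden learning rate set to η_base/N as in Eq. (58). It is a minimal sketch under stated assumptions, not the authors' code: the helper `mup_group_scaling`, the `GroupScaling` record, and the `base_std` constant standing in for the unspecified constants inside each Θ(·) are all hypothetical.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class GroupScaling:
    """Initialization std and effective learning rate for one parameter group."""
    init_std: float
    lr: float


def mup_group_scaling(width: int, base_lr: float,
                      base_std: float = 1.0) -> dict[str, GroupScaling]:
    """Hypothetical helper: muP-style per-group settings for a PT of hidden width N.

    `base_std` is an assumed constant standing in for the factors hidden inside
    each Theta(.) in the appendix; `base_lr` plays the role of eta_base.
    """
    n = width
    return {
        # Input group (S and all biases): sigma_in = Theta(1), eta_in_mult = 1.
        "input": GroupScaling(init_std=base_std, lr=base_lr * 1.0),
        # Hidden group (U, V, B): sigma_hid = Theta(1/sqrt(N)), eta_hid_mult = 1/N.
        "hidden": GroupScaling(init_std=base_std / math.sqrt(n), lr=base_lr / n),
        # Output group (W_out): sigma_out = Theta(1/N), eta_out_mult = 1.
        "output": GroupScaling(init_std=base_std / n, lr=base_lr * 1.0),
    }


# Reuse one base learning rate (tuned at the small width) at a 4x larger width.
small = mup_group_scaling(width=256, base_lr=1e-2)
large = mup_group_scaling(width=1024, base_lr=1e-2)
for group in ("input", "hidden", "output"):
    print(f"{group:6s}  small={small[group]}  large={large[group]}")
```

Reusing `base_lr` across widths leaves the input and output learning rates unchanged while the hidden-group learning rate shrinks as 1/N, which is the kind of automatic adjustment that lets a small-model hyperparameter search carry over without retuning.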

Load-bearing premise

That muP rescaling can be applied to PT's parameters to achieve stable hyperparameter transfer without degrading the model's probabilistic properties or requiring architecture-specific adjustments beyond rescaling.

What would settle it

Train a 0.4B-parameter PT model using hyperparameters transferred from a small model but without muP rescaling, then check whether it fails to train stably or loses its performance advantage over a standard Transformer of equal size.

Figures

Figures reproduced from arXiv: 2604.25409 by Haoyi Wu, Kewei Tu, Penghao Kuang.

Figure 1. Performance comparison between perturbed sampling points and benchmark points.
Figure 2. Performance comparison between PT and BERT, and between PT and Universal Transformer across …
Original abstract

Probabilistic Transformer (PT), a white-box probabilistic model for contextual word representation, has demonstrated substantial similarity to standard Transformers in both computational structure and downstream task performance on small models and small- to medium-sized datasets. However, PT is less robust to hyperparameter choices than standard Transformers, making it harder to scale efficiently. In this work, we follow Maximal Update Parametrization (muP) to rescale PT's parameters, so that hyperparameters optimized on small models can be transferred to larger models without additional tuning. With this approach, we successfully scale PT to models with up to 0.4B parameters. Experiments show that PT consistently outperforms the standard Transformer under the same parameter budget on Masked Language Modeling (MLM) tasks. We hope this work will contribute to the practical deployment of probabilistic models at substantially larger scales in the future.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

0 major / 1 minor

Summary. The paper claims that applying Maximal Update Parametrization (muP) to rescale the parameters of the Probabilistic Transformer (PT) enables stable hyperparameter transfer from small to large models. This approach allows scaling PT to models with up to 0.4B parameters. Experiments show that the resulting PT models consistently outperform standard Transformers on Masked Language Modeling (MLM) tasks under matched parameter budgets.

Significance. If the empirical results hold, the work provides a practical route to scale white-box probabilistic models, which could offer interpretability advantages over standard Transformers at larger scales. The explicit use of muP for cross-scale transfer is a strength, as it leverages prior literature to avoid extensive retuning and supports reproducible scaling experiments.

minor comments (1)
  1. [Abstract and Experiments] The abstract reports successful scaling and outperformance but provides no details on experimental setup, baselines, error bars, or statistical significance; the full paper should expand on these in the experiments section to allow proper evaluation of the central performance claim.

Simulated Author's Rebuttal

0 responses · 0 unresolved

We thank the referee for the positive summary, recognition of the significance of applying muP to Probabilistic Transformers, and the recommendation for minor revision. We are pleased that the work's potential for scaling white-box models is acknowledged.

Circularity Check

0 steps flagged

No significant circularity detected

full rationale

The paper applies the externally established Maximal Update Parametrization (muP) to rescale Probabilistic Transformer parameters, enabling hyperparameter transfer, then reports direct empirical comparisons of scaled PT versus standard Transformers on MLM tasks at matched budgets up to 0.4B parameters. No load-bearing step reduces by construction to a self-definition, a fitted input renamed as prediction, or a self-citation chain; the derivation relies on prior independent literature for muP and on experimental outcomes for the performance claim, making the chain self-contained.

Axiom & Free-Parameter Ledger

0 free parameters · 0 axioms · 0 invented entities

The work rests on the transferability of muP from standard transformers to PT without new free parameters introduced in this paper; no invented entities or ad-hoc axioms are stated in the abstract.

pith-pipeline@v0.9.0 · 5437 in / 987 out tokens · 50200 ms · 2026-05-07T16:26:04.402591+00:00 · methodology

Reference graph

Works this paper leans on

21 extracted references · 1 canonical work page

  1. [1]

    Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877.

  2. [2]

    David M. Blei, Andrew Y. Ng, and Michael I. Jordan. Latent Dirichlet allocation. Journal of Machine Learning Research, 3(Jan):993–1022.
    Tri Dao, Daniel Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. 2022. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing Systems, volume 35, pages 10428–10440.
    M. Dehghani, S. Gouws, O. Vinyals, J. U...

  3. [3]

    BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics (NAACL-HLT).

  4. [4]

    Jean Kaddour, Joshua Harris, Maximilian Mozes, Herbie Bradley, Roberta Raileanu, and Robert McHardy. 2023. The MiniPile challenge for data-efficient language models. arXiv preprint arXiv:2304.08442.
    John D. Lafferty, Andrew McCallum, and Fernando C. N. Pereira. 2001. Conditional random fields: Probabilistic models for segmenting and labeling sequence data. In Proceedings of the Eighteenth International Conference on Machine Learning.
    H. Ramsauer, B. Schäfl, J. L...

  5. [5]

    Scaling white-box transformers for vision. In Advances in Neural Information Processing Systems (NeurIPS).
    A. Equivalence Between Modified Variational Inference and Mathematical Essence. This section re-derives the Mean Field Variational Inference (MFVI) process to demonstrate that the reconstruction of potential functions and variational free energy math...

  6. [6]

    Ternary Factor: $\phi_t(H_i = j, Z_i = a, Z_j = b) = \exp\left(\frac{\tau}{N} T^{(c)}_{a,b}\right)$ if $H_i = j$, and 1 otherwise.

  7. [7]

    attention feature dimension

    Binary Factor: $\phi_b(Z_i = a, G_i = g) = \exp\left(\frac{\tau}{M} B_{g,a}\right)$. Here, $T^{(c)}$ takes the low-rank decomposition form $U^{(c)} V^{(c)\top}$, $N$ is the hidden-state dimension, $r$ is the rank of the ternary factor, and $M$ is the dimension of the binary factor. Given a sequence $w$, the joint probability distribution of the system, $P(Z, H, G \mid w)$, is defined as $P(Z, H, G \mid w) = \frac{1}{Z} \prod_{i=1}^{n} \phi_u(Z_i) \ldots$

  8. [8]

    Dimension-Adaptive Quasi-Distribution Scaling: the binary and ternary terms in the update equation automatically carry their corresponding dimension scaling factors (i.e., $\tilde{Q}_z = N Q_z$ and $\tilde{Q}_G = M Q_G$), precisely converting probabilities into feature vectors whose coordinates have a constant magnitude of $\Theta(1)$.

  9. [9]

    Absolute Scale Invariance: the magnitude of the logits is determined entirely by the parameter matrices (e.g., $S$, $T$, $B$) and is fully decoupled from the model widths $N$ and $M$.
    B. Unified Adaptability of Potential Function and Variational Free Energy Modifications in the Head Selection Module. In the Head Selection module of PT, due to the expansion of the model ...

  10. [10]

    Magnitude of the Entropy Function. For the entropy function $H(Q) = -\sum_{a=1}^{N} Q(a) \ln Q(a)$ (25), we consider its evolution in stages. Early Training Stage (uniform distribution): at this point, $H(Q) = \sum_{a=1}^{N} \frac{1}{N} \ln N = \ln N$ (26). Combined with the setting $\tau = \Theta(N)$, its overall magnitude is $\tau H(Q) = \Theta(N \ln N)$ (27). Convergence Stage (confident distribution): at this point, ...

  11. [11]

    Magnitude of the Unary Potential Term. For the unary potential term $E_{\text{unary}} = -\frac{\tau}{N} \sum_{a=1}^{N} \tilde{Q}_i(a) S_{w_i,a}$ (29), the derivation is as follows. Early Training Stage: the $N$ individual terms of parameter $S$ can be viewed as mutually independent, with their sum exhibiting random-walk behavior. The variance is $\mathrm{Var}\left(\sum \tilde{Q} S\right) = N \cdot \mathrm{Var}(\tilde{Q} S) = \Theta(N)$ (30), making the sum of abs...

  12. [12]

    Magnitude of the Binary Potential Term. For the binary potential term $E_{\text{binary}} = -\frac{\tau}{N} \sum_{a=1}^{N} \tilde{Q}_i(a) \sum_{g=1}^{M} \tilde{Q}^G_i(g) B_{g,a}$ (34), the derivation is as follows. Early Training Stage: according to the µP initialization principles, $B_{g,a} = \Theta(1/\sqrt{N})$ (35). Independent random walks bring the inner sum to $\Theta(\sqrt{M/N}) = \Theta(1)$ (since $M \propto N$), and the random walk of the total su...

  13. [13]

    Magnitude of the Ternary Potential Term. For the single-channel energy $E^{(c)}_{\text{ternary}} = -\frac{\tau}{N} \sum_{i,j} Q_{ic}(j) \sum_{a,b} \tilde{Q}_i(a) \tilde{Q}_j(b) T^{(c)}_{a,b}$ (40), the derivation is as follows. Convergence Stage: utilizing the low-rank decomposition $T_{a,b} = \sum_{l=1}^{r} U_{a,l} V_{b,l}$ (41), and because $r = \Theta(1)$, the aligned learned parts are $U^{\text{learn}}, V^{\text{learn}} = \Theta(1/N)$. After the double superpositio...

  14. [14]

    Magnitude of the Entropy Function. Based on the setting $\tau = \Theta(1)$: in the Early Training Stage, the overall magnitude of the entropy term is $\tau H(Q) = \Theta(1) \cdot \ln N = \Theta(\ln N)$ (43); in the Convergence Stage, it is $\tau H(Q) = \Theta(1) \cdot \Theta(1) = \Theta(1)$ (44).

  15. [15]

    Magnitude of the Energy Function (Unary and Binary Potentials). Early Training Stage: following the derivation in Paradigm 1, the unscaled original sum of energies is $\Theta(\sqrt{N})$. Substituting the new coefficient $\frac{\tau}{N} = \frac{\Theta(1)}{N}$ (45), we obtain $E_{\text{unary,binary}} = \Theta(1/\sqrt{N}) \approx 0$ (46). Convergence Stage: the original aligned sum of energies is $\Theta(N)$. Substituting the new coeffi...

  16. [16]

    Magnitude of the Energy Function (Ternary Potential). Convergence Stage: since $r \propto N$, the learned part of the transition matrix is $T^{\text{learn}}_{a,b} = \sum_{l=1}^{r} U_{a,l} V_{b,l} \sim r \cdot \left(\frac{1}{N} \cdot \frac{1}{N}\right) = \Theta(1/N)$ (48). The double summation $\sum_{a,b=1}^{N} \tilde{Q}(a) \tilde{Q}(b) T_{a,b}$ is equivalent to the coherent superposition of $N^2$ terms, yielding an original magnitude of $N^2 \cdot \Theta(1/N) = \Theta(N)$ (49). Appl...

  17. [17]

    Input Group: primarily includes the unary potential matrix $S \in \mathbb{R}^{V \times N}$ and all biases. These parameters map discrete or low-dimensional inputs to a high-dimensional space.
    • Initialization: set standard deviation $\sigma_{\text{in}} = \Theta(1)$.
    • Learning rate: set scaling factor $\eta_{\text{in\_mult}} = 1$.

  18. [18]

    Hidden Group: includes the low-rank decomposition matrices $U, V \in \mathbb{R}^{N \times r}$ of the ternary factor and the binary factor matrix $B \in \mathbb{R}^{M \times N}$ (where $M \propto N$). These parameters dictate interactions between high-dimensional representations.
    • Initialization: set standard deviation $\sigma_{\text{hid}} = \Theta(1/\sqrt{N})$.
    • Learning rate: set scaling factor $\eta_{\text{hid\_mult}} = 1/N$.

  19. [19]

    Output Group: refers to the decoding matrix $W_{\text{out}} \in \mathbb{R}^{N \times V}$ in the prediction head. These parameters project high-dimensional representations back to scalar scores.
    • Initialization: set standard deviation $\sigma_{\text{out}} = \Theta(1/N)$.
    • Learning rate: set scaling factor $\eta_{\text{out\_mult}} = 1$.
    C.2 Principle 1: Coordinate Magnitude Stability (Forward Pass Stability). Principle 1 Re...

  20. [20]

    Hidden Layer Parameter Updates: for the hidden layer, the change in its activations is $\Delta y_i = \sum_{j=1}^{N} \Delta W_{ij} x_j \approx \eta_{\text{hid}} \sum_{j=1}^{N} \mathrm{sign}(g_{ij}) x_j$ (57). Due to the coherent superposition effect, the magnitude of the summation term reaches $\Theta(N)$. To achieve $\Delta y_i = \Theta(1)$, the following condition must be met: $\eta_{\text{hid}} = \eta_{\text{base}} \cdot \eta_{\text{hid\_mult}} = \eta_{\text{base}} \cdot \frac{1}{N}$ (58). This explains why th...

  21. [21]

    Input and Output Layer Updates: the updates to the input layer (unary potential terms) do not involve linear accumulation across dimensions; their magnitude of change is directly determined by $\eta_{\text{in}}$. Thus, keeping $\eta_{\text{in\_mult}} = 1$ is sufficient to achieve an update of $\Theta(1)$. For the output layer, to counteract the $1/N$ suppression at initialization and en...
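Entries 10 and 14 compare the entropy term under the two temperature settings. The short check below evaluates τH(Q) for a uniform Q (the early training stage) at several widths, reproducing Θ(N ln N) growth for τ = N (Eq. 27) and Θ(ln N) growth for τ = 1 (Eq. 43). The function names are illustrative, and taking τ exactly equal to N or to 1 is an assumption standing in for the Θ(·) settings.

```python
import math


def entropy(q: list[float]) -> float:
    """Shannon entropy H(Q) = -sum_a Q(a) ln Q(a); zero-probability entries contribute 0."""
    return -sum(p * math.log(p) for p in q if p > 0.0)


def uniform_entropy_term(n: int, tau: float) -> float:
    """tau * H(Q) for the uniform distribution Q(a) = 1/N (early training stage)."""
    return tau * entropy([1.0 / n] * n)


for n in (256, 1024, 4096):
    big_tau = uniform_entropy_term(n, tau=float(n))  # tau = Theta(N): grows like N ln N
    unit_tau = uniform_entropy_term(n, tau=1.0)      # tau = Theta(1): grows like ln N
    print(f"N={n:5d}  tau=N: {big_tau:10.1f}  tau=1: {unit_tau:6.3f}")
```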
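The magnitude arguments in entries 11, 12, and 20 rest on one contrast: a sum of N independent Θ(1) terms behaves like a random walk of size Θ(√N), while N aligned terms add coherently to Θ(N). The simulation below is a toy check of that contrast, not code from the paper. It assumes each product such as Q̃_i(a)·S_{w_i,a} can be modeled as an independent ±1 (or ±scale) coin flip and takes M = N for the binary inner sum; the function names are hypothetical.

```python
import math
import random


def rms_random_sum(n_terms: int, scale: float = 1.0,
                   trials: int = 400, seed: int = 0) -> float:
    """Root-mean-square of a sum of n_terms i.i.d. +/-scale terms (a random walk)."""
    rng = random.Random(seed)
    sq_total = 0.0
    for _ in range(trials):
        s = sum(scale if rng.random() < 0.5 else -scale for _ in range(n_terms))
        sq_total += s * s
    return math.sqrt(sq_total / trials)


def coherent_sum(n_terms: int, scale: float = 1.0) -> float:
    """Sum of n_terms aligned terms of size scale (coherent superposition)."""
    return n_terms * scale


for n in (64, 256, 1024):
    walk = rms_random_sum(n)                           # ~ sqrt(N), as in Eq. (30)
    inner = rms_random_sum(n, scale=1 / math.sqrt(n))  # M = N terms of size 1/sqrt(N): ~ 1
    aligned = coherent_sum(n)                          # ~ N, hence eta_hid carries 1/N (Eq. 58)
    print(f"N={n:5d}  walk/sqrt(N)={walk / math.sqrt(n):.2f}  inner={inner:.2f}  "
          f"aligned/N={aligned / n:.1f}")
```

The random-walk ratio stays near 1 as N grows, whereas the coherent sum grows linearly, which is why only the aligned (trained) regime forces the 1/N learning-rate multiplier on the hidden group.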
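Entries 13 and 16 bound the ternary potential through the low-rank factorization T = UVᵀ. The toy computation below fills every aligned entry of U and V with exactly 1/N and sets Q̃(a) = N·Q(a) = 1 for a uniform Q (the scaling from entry 8), then evaluates the double sum over a and b of Q̃(a)Q̃(b)T_{a,b} by brute force. Under these simplifying assumptions the sum equals r exactly, so it stays Θ(1) when r = Θ(1) and grows as Θ(N) when r ∝ N, in line with Eq. (49). The helper name and the perfectly aligned entries are illustrative, not the paper's setup.

```python
def ternary_double_sum(n: int, r: int) -> float:
    """Brute-force sum_{a,b} Qt(a) Qt(b) T_{a,b}, with T = U V^T (Eq. 41).

    Toy assumptions: every aligned entry of U and V equals 1/N, and
    Qt(a) = N * Q(a) = 1 for the uniform distribution Q(a) = 1/N.
    """
    u = [[1.0 / n] * r for _ in range(n)]   # U in R^{N x r}
    v = [[1.0 / n] * r for _ in range(n)]   # V in R^{N x r}
    q = [1.0 / n] * n                        # uniform Q
    q_tilde = [n * p for p in q]             # Qt = N * Q, coordinates of size 1
    total = 0.0
    for a in range(n):
        for b in range(n):
            t_ab = sum(u[a][l] * v[b][l] for l in range(r))  # r / N^2 per entry
            total += q_tilde[a] * q_tilde[b] * t_ab
    return total


for n in (16, 32, 64):
    fixed_rank = ternary_double_sum(n, r=4)         # r = Theta(1): stays at 4
    growing_rank = ternary_double_sum(n, r=n // 4)  # r proportional to N: grows with N
    print(f"N={n:3d}  r=4: {fixed_rank:.3f}  r=N/4: {growing_rank:.3f}")
```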