GenSBI: Generative Methods for Simulation-Based Inference in JAX
Pith reviewed 2026-06-29 18:42 UTC · model grok-4.3
The pith
GenSBI supplies a JAX library that runs flow matching, score matching, and denoising diffusion for simulation-based inference.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
GenSBI implements flow matching, score matching, and denoising diffusion entirely in JAX. It ships three transformer architectures (SimFormer, Flux1, and the novel Flux1Joint) behind a single interface that decouples generative method, neural backbone, and inference mode. On SBIBM tasks it reports near-ideal mean C2ST scores (0.50-0.56, where 0.50 is ideal) with minimal per-task tuning.
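To make the first of the three objectives concrete, here is a minimal sketch of a conditional flow matching loss in JAX. It assumes a linear interpolation path between Gaussian noise and the parameter sample, which is one common choice and not necessarily the one GenSBI uses. The names `cfm_loss` and `linear_velocity` are hypothetical and are not taken from the GenSBI API; the single linear layer only stands in for a transformer backbone.

```python
import jax
import jax.numpy as jnp

def cfm_loss(velocity_fn, params, key, theta, x):
    """Conditional flow matching loss with a linear noise-to-data path.

    theta: (batch, d) parameter samples (the data end of the path, t = 1).
    x:     (batch, m) simulator outputs used as conditioning.
    """
    key_t, key_noise = jax.random.split(key)
    t = jax.random.uniform(key_t, (theta.shape[0], 1))   # t ~ U(0, 1)
    noise = jax.random.normal(key_noise, theta.shape)    # base sample at t = 0
    theta_t = (1.0 - t) * noise + t * theta              # point on the path
    target = theta - noise                               # velocity along the path
    pred = velocity_fn(params, theta_t, t, x)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))

# Hypothetical stand-in backbone: one linear layer on [theta_t, t, x].
def linear_velocity(params, theta_t, t, x):
    inputs = jnp.concatenate([theta_t, t, x], axis=-1)
    return inputs @ params["w"] + params["b"]

key = jax.random.PRNGKey(0)
k_theta, k_x, k_loss = jax.random.split(key, 3)
theta = jax.random.normal(k_theta, (64, 2))
x = theta + 0.1 * jax.random.normal(k_x, (64, 2))
params = {"w": jnp.zeros((5, 2)), "b": jnp.zeros(2)}

loss, grads = jax.value_and_grad(
    lambda p: cfm_loss(linear_velocity, p, k_loss, theta, x)
)(params)
```

Because the loss only sees `velocity_fn` as a callable, any backbone with the same signature could be swapped in. That is the kind of method/backbone decoupling the core claim describes, shown here only in toy form.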
What carries the argument
The unified interface that decouples generative method from neural backbone from inference mode, together with the gate-modulated transformer blocks extended to joint density estimation.
If this is right
- JAX users can train and deploy generative SBI models without switching frameworks for the inference step.
- The same code base supports posterior, likelihood, and joint-density estimation by swapping only the inference mode flag.
- Custom domain-specific embedding networks can be dropped in without rewriting the generative training loop.
- Built-in SBC, TARP, and LC2ST routines allow immediate checking of posterior calibration after training.
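For readers unfamiliar with the calibration checks in the last bullet, the following is a minimal sketch of how SBC ranks can be computed in JAX. It is not GenSBI's routine: `sbc_ranks` and `exact_posterior` are hypothetical names, and a conjugate Gaussian toy problem with a known posterior stands in for a trained estimator. With a well-calibrated posterior, the ranks should be roughly uniform.

```python
import jax
import jax.numpy as jnp

def sbc_ranks(key, sample_posterior, theta_true, x_obs, num_posterior_samples):
    """Rank of each true parameter among draws from its own posterior."""
    keys = jax.random.split(key, theta_true.shape[0])

    def one_rank(k, theta, x):
        draws = sample_posterior(k, x, num_posterior_samples)
        return jnp.sum(draws < theta)

    return jax.vmap(one_rank)(keys, theta_true, x_obs)

# Toy problem (assumption): theta ~ N(0, 1), x | theta ~ N(theta, SIGMA^2).
SIGMA = 0.5

def exact_posterior(key, x, n):
    # Closed-form Gaussian posterior, standing in for a trained estimator.
    var = SIGMA**2 / (1.0 + SIGMA**2)
    mean = x / (1.0 + SIGMA**2)
    return mean + jnp.sqrt(var) * jax.random.normal(key, (n,))

key = jax.random.PRNGKey(0)
k_theta, k_x, k_sbc = jax.random.split(key, 3)
theta_true = jax.random.normal(k_theta, (2000,))
x_obs = theta_true + SIGMA * jax.random.normal(k_x, (2000,))

ranks = sbc_ranks(k_sbc, exact_posterior, theta_true, x_obs, 99)
counts = jnp.bincount(ranks, length=100)  # histogram over ranks 0..99
```

A U-shaped or dome-shaped `counts` histogram would point to an over- or under-confident posterior respectively, while a skewed histogram would suggest bias.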
Where Pith is reading between the lines
- The availability of native JAX implementations may lower the barrier for teams whose forward simulators are already written in JAX or JAX-based autodiff ecosystems.
- Because the architectures are interchangeable, the library could serve as a testbed for comparing how different transformer variants affect calibration across SBI tasks.
- Extending the Flux1Joint block to additional modalities or higher-dimensional observations would be a direct next step permitted by the modular design.
Load-bearing premise
The JAX code produces the same numerical behavior and optimization targets as the original PyTorch reference implementations of flow matching, score matching, and denoising diffusion.
What would settle it
Running identical SBIBM tasks in both GenSBI and an established PyTorch SBI library and checking whether the C2ST scores and posterior-coverage diagnostics match within sampling noise.
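The C2ST score in such a comparison is the held-out accuracy of a classifier trained to tell two sample sets apart. The sketch below is a simplified JAX version that assumes a logistic classifier on linear and squared features. SBIBM's reference C2ST uses a neural-network classifier with cross-validation, so numbers from this sketch are not comparable to the reported 0.50-0.56 range. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def features(z):
    # Linear and squared terms, so a variance mismatch is detectable too.
    return jnp.concatenate([z, z**2], axis=-1)

def logistic_loss(w, feats, labels):
    # Binary cross-entropy on logits; the last entry of w is the bias.
    logits = feats @ w[:-1] + w[-1]
    return jnp.mean(jnp.logaddexp(0.0, logits) - labels * logits)

def c2st(key, samples_p, samples_q, steps=500, lr=0.1):
    """Held-out accuracy of a classifier separating samples_p from samples_q."""
    data = jnp.concatenate([samples_p, samples_q])
    labels = jnp.concatenate(
        [jnp.zeros(len(samples_p)), jnp.ones(len(samples_q))]
    )
    perm = jax.random.permutation(key, len(data))
    data, labels = data[perm], labels[perm]

    # Standardize with training-half statistics, then split in half.
    n_train = len(data) // 2
    mean, std = data[:n_train].mean(0), data[:n_train].std(0)
    feats = features((data - mean) / std)

    grad_fn = jax.jit(jax.grad(logistic_loss))
    w = jnp.zeros(feats.shape[1] + 1)
    for _ in range(steps):  # plain gradient descent
        w = w - lr * grad_fn(w, feats[:n_train], labels[:n_train])

    logits = feats[n_train:] @ w[:-1] + w[-1]
    return jnp.mean((logits > 0) == (labels[n_train:] > 0.5))

key = jax.random.PRNGKey(0)
k_p, k_q, k_split = jax.random.split(key, 3)
p = jax.random.normal(k_p, (2000, 2))
q_same = jax.random.normal(k_q, (2000, 2))
q_shifted = q_same + 2.0

score_same = c2st(k_split, p, q_same)        # expected to sit near 0.5
score_shifted = c2st(k_split, p, q_shifted)  # expected to be well above 0.5
```

In the cross-library test proposed above, `samples_p` would be posterior samples from one library and `samples_q` reference posterior samples, with the same procedure repeated for the other library.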
Original abstract
Flow and diffusion generative models have established themselves as widely adopted density estimators for simulation-based inference (SBI), extending naturally from neural posterior estimation to likelihood and joint density estimation. Their principled optimization objectives and freedom from architectural constraints have driven rapid adoption across the natural sciences. Yet the most widely used SBI libraries remain PyTorch-based, leaving researchers who develop their forward models and analysis pipelines in JAX without a native option. We present GenSBI, an open-source library that implements flow matching, score matching, and denoising diffusion entirely in JAX. The library offers three transformer-based architectures - SimFormer, Flux1, and a novel Flux1Joint that extends gate-modulated transformer blocks to joint density estimation - all interchangeable through a unified interface that decouples generative method, neural backbone, and inference mode. GenSBI provides an end-to-end workflow from training through posterior calibration (SBC, TARP, LC2ST) and supports custom architectures with domain-specific embedding networks. We validate the framework on standard SBI benchmarks, achieving near-ideal mean C2ST scores (0.50-0.56, where 0.50 is ideal) on SBIBM tasks with minimal per-task tuning and well-calibrated posterior coverage across all tested configurations. The code is publicly available at https://github.com/aurelio-amerio/GenSBI.
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The manuscript presents GenSBI, an open-source JAX library implementing flow matching, score matching, and denoising diffusion for simulation-based inference. It introduces three interchangeable transformer backbones (SimFormer, Flux1, Flux1Joint) with a unified interface, supports custom embeddings, and provides end-to-end workflows including posterior calibration diagnostics. Validation on SBIBM tasks reports near-ideal mean C2ST scores (0.50-0.56) with minimal per-task tuning and well-calibrated coverage.
Significance. If the implementations are correct, the library fills a clear gap by providing native JAX support for generative SBI methods, allowing seamless use with JAX-based simulators. The modular design decoupling method, backbone, and inference mode, plus explicit support for SBC/TARP/LC2ST calibration, represents a practical contribution. The reported benchmark performance, if verified, would demonstrate utility with limited tuning.
Major comments (1)
- [Experiments] Experiments section (and abstract performance claims): The reported C2ST scores (0.50-0.56) and calibration metrics are presented as evidence of correct implementation, yet no explicit verification is described (e.g., loss-value matching, gradient agreement, or sampling-distribution equivalence tests) against established PyTorch reference implementations of flow/score matching and diffusion. This verification is load-bearing for the central claim that the JAX code faithfully realizes the intended generative objectives.
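The loss-value and gradient agreement tests the referee asks for could look roughly like the sketch below. It is only a same-machine illustration: an independent NumPy implementation with a hand-derived gradient stands in for the PyTorch reference, and the objective is a denoising score matching loss for a hypothetical linear score model, not GenSBI's actual loss. Both sides receive identical data, noise, and weights, so any disagreement beyond floating-point tolerance would indicate an implementation difference.

```python
import jax
import jax.numpy as jnp
import numpy as np

def dsm_loss_jax(w, x, eps, sigma):
    """Denoising score matching loss for a linear score model s(x) = x @ w."""
    x_noisy = x + sigma * eps
    residual = sigma * (x_noisy @ w) + eps   # sigma * score + eps
    return jnp.mean(jnp.sum(residual**2, axis=-1))

def dsm_loss_and_grad_numpy(w, x, eps, sigma):
    """Independent NumPy reference with a hand-derived gradient."""
    x_noisy = x + sigma * eps
    residual = sigma * (x_noisy @ w) + eps
    loss = np.mean(np.sum(residual**2, axis=-1))
    grad = 2.0 * sigma * x_noisy.T @ residual / x.shape[0]
    return loss, grad

# Shared inputs: the same data, noise, and weights go to both implementations.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 3)).astype(np.float32)
eps = rng.normal(size=(256, 3)).astype(np.float32)
w = rng.normal(size=(3, 3)).astype(np.float32)
sigma = 0.5

loss_jax, grad_jax = jax.value_and_grad(dsm_loss_jax)(
    jnp.asarray(w), jnp.asarray(x), jnp.asarray(eps), sigma
)
loss_ref, grad_ref = dsm_loss_and_grad_numpy(w, x, eps, sigma)

loss_ok = np.allclose(loss_jax, loss_ref, rtol=1e-4, atol=1e-4)
grad_ok = np.allclose(np.asarray(grad_jax), grad_ref, rtol=1e-4, atol=1e-4)
```

Fixing the noise draws outside the loss is the key design choice: with framework-specific random number generators, losses can differ for reasons unrelated to the objective itself.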
Minor comments (2)
- The abstract and §3 could clarify the precise SBIBM task suite and version used, as well as the hyperparameter search ranges that support the 'minimal per-task tuning' claim.
- [§3] Notation for the Flux1Joint architecture (gate-modulated blocks for joint density) should be defined explicitly in §3 before use in experiments.
Simulated Author's Rebuttal
We thank the referee for highlighting the importance of implementation verification. We address the single major comment below and will revise the manuscript to strengthen this aspect.
Referee: [Experiments] Experiments section (and abstract performance claims): The reported C2ST scores (0.50-0.56) and calibration metrics are presented as evidence of correct implementation, yet no explicit verification is described (e.g., loss-value matching, gradient agreement, or sampling-distribution equivalence tests) against established PyTorch reference implementations of flow/score matching and diffusion. This verification is load-bearing for the central claim that the JAX code faithfully realizes the intended generative objectives.
Authors: We agree that explicit cross-framework verification would provide stronger evidence for faithful implementation of the generative objectives. The current validation demonstrates that the reported C2ST scores (0.50-0.56) and calibration metrics match the expected near-ideal performance on SBIBM tasks, which would be improbable under significant implementation deviations. However, we acknowledge the referee's point that this constitutes indirect rather than direct evidence. In the revised manuscript we will add a dedicated subsection under Experiments that reports (i) training-loss agreement on shared tasks with available PyTorch references, (ii) gradient-norm comparisons where architectures permit, and (iii) distributional equivalence checks via Kolmogorov-Smirnov tests on posterior samples for at least the flow-matching and score-matching backbones. For denoising diffusion we will note the absence of a directly comparable open-source PyTorch SBI reference and rely on the calibration diagnostics already presented.
Revision: yes
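The Kolmogorov-Smirnov check in item (iii) can be sketched in JAX as follows. This is an assumption about how such a test might be set up, not the authors' procedure. It handles one-dimensional samples only, so for a multi-dimensional posterior it would have to be applied per marginal. The function names are hypothetical, and the rejection threshold uses the standard large-sample approximation.

```python
import jax
import jax.numpy as jnp

def ks_two_sample(a, b):
    """Two-sample Kolmogorov-Smirnov statistic for 1-D samples a and b."""
    grid = jnp.sort(jnp.concatenate([a, b]))
    # Empirical CDFs of both samples evaluated at every pooled point.
    cdf_a = jnp.searchsorted(jnp.sort(a), grid, side="right") / a.shape[0]
    cdf_b = jnp.searchsorted(jnp.sort(b), grid, side="right") / b.shape[0]
    return jnp.max(jnp.abs(cdf_a - cdf_b))

def ks_critical_value(n, m, alpha=0.05):
    """Asymptotic rejection threshold for the two-sample KS statistic."""
    c_alpha = jnp.sqrt(-0.5 * jnp.log(alpha / 2.0))
    return c_alpha * jnp.sqrt((n + m) / (n * m))

key = jax.random.PRNGKey(0)
k_a, k_b = jax.random.split(key)
ref = jax.random.normal(k_a, (2000,))      # stand-in for reference samples
same = jax.random.normal(k_b, (2000,))     # same distribution
shifted = same + 1.0                       # deliberately mismatched

d_same = ks_two_sample(ref, same)
d_shifted = ks_two_sample(ref, shifted)
threshold = ks_critical_value(2000, 2000)
```

A statistic above `threshold` rejects equality of the two marginals at level `alpha`. Testing many marginals at once would call for a multiple-comparison correction, which this sketch leaves out.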
Circularity Check
No circularity: performance metrics drawn from external SBIBM benchmarks, not self-defined quantities
The paper describes a JAX library implementing established methods (flow matching, score matching, denoising diffusion) with transformer backbones and reports C2ST scores and calibration on SBIBM tasks. These benchmarks are independent external standards; the results are not predictions derived from the library's own fitted parameters or reduced by self-citation chains. No load-bearing self-definitional steps, fitted-input predictions, or ansatz smuggling appear in the abstract or described claims. The central validation rests on external data rather than internal redefinitions.