FlashSinkhorn: IO-Aware Entropic Optimal Transport on GPU
Pith reviewed 2026-05-22 11:26 UTC · model grok-4.3
The pith
FlashSinkhorn rewrites Sinkhorn iterations as fused row-wise LogSumExp reductions on biased dot products to cut GPU memory traffic for entropic optimal transport.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
Stabilized Sinkhorn updates for squared Euclidean costs can be expressed as row-wise LogSumExp reductions of biased dot-product scores, enabling FlashAttention-style fused tiling in Triton kernels that stream tiles through SRAM, update dual potentials in one pass, and substantially lower HBM IO per iteration while retaining linear memory and numerical stability.
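The rewrite behind the core claim can be checked by hand. The derivation below is our own. It assumes the common convention $P_{ij} = a_i b_j \exp\big((f_i + g_j - C_{ij})/\varepsilon\big)$, with marginal weights $a, b$ and dual potentials $f, g$. The paper's exact scaling of scores and potentials may differ.

```latex
\begin{aligned}
C_{ij} &= \|x_i\|^2 + \|y_j\|^2 - 2\,\langle x_i, y_j\rangle, \\
f_i &\leftarrow -\varepsilon \log \sum_{j=1}^{m} b_j \exp\!\Big(\frac{g_j - C_{ij}}{\varepsilon}\Big)
     = \|x_i\|^2 - \varepsilon \, \operatorname{LSE}_{j}\big(s_{ij}\big), \\
s_{ij} &= \frac{2}{\varepsilon}\,\langle x_i, y_j\rangle + \beta_j,
\qquad \beta_j = \frac{g_j - \|y_j\|^2}{\varepsilon} + \log b_j .
\end{aligned}
```

The $g$-update is the same expression with rows and columns exchanged. The $\langle x_i, y_j\rangle$ term plays the role of a query-key product, and $\beta_j$ plays the role of an additive per-column bias. FlashAttention-style kernels already know how to tile this pattern.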
What carries the argument
Fused Triton kernels that perform row-wise LogSumExp reductions on biased dot-product scores with FlashAttention-style tiling and SRAM streaming to update dual potentials in a single pass.
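A single-pass, tiled row reduction of this kind can be sketched on the CPU with NumPy. This is not the authors' Triton kernel. The names `tiled_row_lse` and `block_m` are hypothetical, and the score convention $s_{ij} = (2/\varepsilon)\langle x_i, y_j\rangle + \beta_j$ is an assumption. The sketch shows the running-max merge (the online normalizer of Milakov and Gimelshein [37]), which allows each tile to be discarded once it has been folded into the per-row statistics.

```python
import numpy as np


def tiled_row_lse(X, Y, beta, eps, block_m=64):
    """Row-wise LogSumExp of s_ij = (2/eps) <x_i, y_j> + beta_j (hypothetical helper).

    Columns are visited one tile at a time. Each row keeps only a running
    maximum and a running sum of exponentials, so at most an (n, block_m)
    tile of scores exists at once (the online-softmax merge)."""
    n = X.shape[0]
    run_max = np.full(n, -np.inf)
    run_sum = np.zeros(n)
    for start in range(0, Y.shape[0], block_m):
        stop = start + block_m
        # scores for this tile: scaled dot products plus the column bias
        S = (2.0 / eps) * (X @ Y[start:stop].T) + beta[start:stop]
        new_max = np.maximum(run_max, S.max(axis=1))
        # rescale the old partial sum to the new maximum, then add this tile
        run_sum = run_sum * np.exp(run_max - new_max) + np.exp(S - new_max[:, None]).sum(axis=1)
        run_max = new_max
    return run_max + np.log(run_sum)


# Usage: compare against a dense, max-subtracted LogSumExp
rng = np.random.default_rng(0)
X, Y = rng.random((300, 8)), rng.random((500, 8))
beta, eps = rng.normal(size=500), 0.1
S_full = (2.0 / eps) * (X @ Y.T) + beta
m = S_full.max(axis=1)
dense = m + np.log(np.exp(S_full - m[:, None]).sum(axis=1))
print(np.max(np.abs(tiled_row_lse(X, Y, beta, eps) - dense)))  # rounding-level difference
```

On a GPU, each program instance would typically own a block of rows and keep `run_max` and `run_sum` on chip while tiles of `Y` stream through SRAM. The NumPy loop mirrors only the arithmetic, not the memory hierarchy.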
If this is right
- Up to 32 times faster forward passes than state-of-the-art online baselines on A100 GPUs for point-cloud optimal transport.
- Up to 161 times faster end-to-end optimization when the transport plan is used inside first- or second-order methods.
- Improved scalability for downstream tasks that repeatedly solve optimal transport problems.
- Linear-memory streaming kernels for applying the transport plan without materializing dense matrices.
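The same tiling can be used to apply the plan without materializing it. The NumPy function below is a hypothetical illustration, not the released API. It computes $PV$ for $P_{ij} = a_i b_j \exp((f_i + g_j - C_{ij})/\varepsilon)$, which is an assumed convention. Each column tile of the squared-Euclidean cost is recomputed from $X$ and $Y$ and discarded after use, so the extra memory is O(n × block_m) rather than O(nm).

```python
import numpy as np


def apply_plan_streaming(X, Y, a, b, f, g, V, eps, block_m=64):
    """Hypothetical helper: return P @ V for P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps),
    C_ij = ||x_i - y_j||^2, visiting one column tile of P at a time."""
    x_sq = (X ** 2).sum(axis=1)
    out = np.zeros((X.shape[0], V.shape[1]))
    for start in range(0, Y.shape[0], block_m):
        sl = slice(start, start + block_m)
        Yt = Y[sl]
        # squared-Euclidean cost of this tile via the dot-product expansion
        C = x_sq[:, None] + (Yt ** 2).sum(axis=1)[None, :] - 2.0 * (X @ Yt.T)
        log_p = (f[:, None] + g[sl][None, :] - C) / eps + np.log(a)[:, None] + np.log(b[sl])[None, :]
        out += np.exp(log_p) @ V[sl]  # accumulate this tile's contribution
    return out


# Usage with arbitrary (non-converged) potentials; the identity holds for any f, g
rng = np.random.default_rng(1)
X, Y = rng.random((50, 3)), rng.random((70, 3))
a, b = np.full(50, 1 / 50), np.full(70, 1 / 70)
f, g = 0.1 * rng.normal(size=50), 0.1 * rng.normal(size=70)
V = rng.normal(size=(70, 2))
PV = apply_plan_streaming(X, Y, a, b, f, g, V, eps=0.5, block_m=16)
```

Exponentiating without a max shift is safe when $f$ and $g$ come from a Sinkhorn half-step. After an $f$-update, each row of $P$ sums to $a_i \le 1$, and after a $g$-update, each column sums to $b_j \le 1$. Either way every entry satisfies $P_{ij} \le 1$, so the exponent is at most zero. Computing $P^\top U$ works the same way with the roles of rows and columns swapped.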
Where Pith is reading between the lines
- The same IO-fusion pattern could be applied to other iterative algorithms whose inner steps resemble attention or matrix-normalization operations.
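- Extending the kernels to other cost functions would, in principle, require only a change in the score computation while keeping the LogSumExp tiling intact. Costs that do not factor into a dot product plus row and column biases would, however, lose the matrix-multiply formulation behind the fused kernels.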
- Integration into larger machine-learning pipelines could reduce the wall-clock cost of using entropic OT as a regularizer or loss term.
Load-bearing premise
The cost must be the squared Euclidean distance, so that each Sinkhorn score reduces to a biased dot product. The chosen tiling must also preserve the numerical stability of the log-domain updates without extra safeguards.
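The stability half of this premise rests on max subtraction inside LogSumExp. Below is a minimal NumPy illustration. The function names are ours, and the score values are chosen only to mimic a small $\varepsilon$. Because the scores scale like $1/\varepsilon$, once they pass roughly 700 in magnitude a naive log-sum-exp underflows to $-\infty$ or overflows to $+\infty$ in float64. The shifted form stays finite.

```python
import numpy as np


def naive_lse(s):
    # direct log(sum(exp(s))): exp under/overflows once |s| exceeds roughly 700 in float64
    return np.log(np.exp(s).sum())


def stable_lse(s):
    # subtract the maximum first so the largest exponent is exactly exp(0) = 1
    m = s.max()
    return m + np.log(np.exp(s - m).sum())


eps = 1e-3
s_low = np.array([-2.0, -2.001, -2.002]) / eps  # e.g. -C_ij / eps with C_ij near 2
s_high = np.array([2.0, 1.999, 1.998]) / eps    # e.g. (2/eps) <x_i, y_j> with a dot product near 1
with np.errstate(over="ignore", divide="ignore"):
    print(naive_lse(s_low), naive_lse(s_high))  # -inf inf
print(stable_lse(s_low), stable_lse(s_high))    # both finite
```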
What would settle it
Measure forward-pass time and solution accuracy on a non-Euclidean cost, using fused kernels adapted to that cost. Then check whether the speedups vanish or the accuracy degrades relative to a standard implementation.
Original abstract
Entropic optimal transport (EOT) via Sinkhorn iterations is widely used in modern machine learning, yet GPU solvers remain inefficient at scale. Tensorized implementations suffer quadratic HBM traffic from dense $n\times m$ interactions. Existing online backends avoid storing dense matrices but still rely on generic tiled map-reduce kernels with limited fusion. We present FlashSinkhorn, an IO-aware EOT solver for squared Euclidean cost. It rewrites stabilized log-domain Sinkhorn updates as row-wise LogSumExp reductions of biased dot-product scores, the same normalization used in transformer attention. This enables FlashAttention-style fusion and tiling: fused Triton kernels stream tiles through on-chip SRAM and update dual potentials in a single pass, substantially reducing HBM IO per iteration while retaining linear-memory operations. We further provide streaming kernels for transport application, enabling scalable first- and second-order optimization. On A100 GPUs, FlashSinkhorn achieves up to $32\times$ forward-pass and $161\times$ end-to-end speedups over state-of-the-art online baselines on point-cloud OT, and it improves scalability on OT-based downstream tasks. For reproducibility, we release an open-source implementation at https://github.com/ot-triton-lab/flash-sinkhorn.
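The rewrite described in the abstract can be sanity-checked end to end. The sketch assumes the convention $P_{ij} = a_i b_j \exp((f_i+g_j-C_{ij})/\varepsilon)$. It runs alternating log-domain updates twice, once in the biased dot-product form and once directly on the cost matrix, and confirms that the two agree. It is dense for clarity; a fused kernel would stream the scores in tiles instead. `sinkhorn_dot`, `sinkhorn_cost` and `lse` are hypothetical names. The values ε = 0.1 and 10 iterations mirror the forward/backward benchmark setting reported in the paper.

```python
import numpy as np


def lse(S, axis):
    """Max-subtracted LogSumExp along `axis`."""
    m = S.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(S - m).sum(axis=axis))


def sinkhorn_dot(X, Y, a, b, eps, n_iter=10):
    """Hypothetical sketch: log-domain Sinkhorn for C_ij = ||x_i - y_j||^2 written as
    LogSumExp over biased dot-product scores (dense here, not tiled)."""
    x_sq, y_sq = (X ** 2).sum(axis=1), (Y ** 2).sum(axis=1)
    G = (2.0 / eps) * (X @ Y.T)  # dot-product part of every score
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iter):
        # row update: per-column bias beta_j = (g_j - ||y_j||^2)/eps + log b_j
        f = x_sq - eps * lse(G + ((g - y_sq) / eps + np.log(b))[None, :], axis=1)
        # column update: per-row bias (f_i - ||x_i||^2)/eps + log a_i
        g = y_sq - eps * lse(G + ((f - x_sq) / eps + np.log(a))[:, None], axis=0)
    return f, g


def sinkhorn_cost(C, a, b, eps, n_iter=10):
    """Reference: the same updates written directly on the cost matrix."""
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iter):
        f = -eps * lse((g[None, :] - C) / eps + np.log(b)[None, :], axis=1)
        g = -eps * lse((f[:, None] - C) / eps + np.log(a)[:, None], axis=0)
    return f, g


rng = np.random.default_rng(0)
X, Y = rng.random((200, 4)), rng.random((150, 4))
a, b = np.full(200, 1 / 200), np.full(150, 1 / 150)
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)  # dense cost, reference only
f1, g1 = sinkhorn_dot(X, Y, a, b, eps=0.1)
f2, g2 = sinkhorn_cost(C, a, b, eps=0.1)
print(np.max(np.abs(f1 - f2)), np.max(np.abs(g1 - g2)))  # both at rounding level
```

Because the last half-step is a $g$-update, the column marginal of the resulting plan is exact up to rounding. The row marginal is only approximately satisfied after 10 iterations.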
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The paper presents FlashSinkhorn, an IO-aware GPU implementation of entropic optimal transport via Sinkhorn iterations for squared Euclidean costs. It rewrites the stabilized log-domain updates as row-wise LogSumExp reductions over biased dot-product scores (analogous to transformer attention), enabling FlashAttention-style fusion and tiling in Triton kernels that stream tiles through SRAM to reduce HBM traffic while keeping linear memory. Reported results include up to 32× forward-pass and 161× end-to-end speedups on A100 GPUs versus state-of-the-art online baselines for point-cloud OT, plus improved scalability for downstream OT tasks. The authors release open-source code.
Significance. If numerical equivalence to reference Sinkhorn is preserved, the work offers a practical advance in scaling OT computations in machine learning by cutting memory IO without sacrificing the core algorithm. The release of reproducible open-source Triton kernels is a clear strength that supports verification and adoption.
major comments (2)
- [§3] The fused Triton kernels are stated to implement the stabilized log-domain Sinkhorn recurrence via per-tile LogSumExp on biased dot products, but the manuscript provides neither a proof of equivalence nor quantitative accuracy tables (e.g., marginal violation or OT-cost error versus the POT baseline) at the largest reported problem sizes. This is load-bearing for the speedup claims, because any drift in the fixed-point dual potentials would render the performance numbers incomparable to the baselines.
- [Abstract] Abstract and experimental section: Speedup figures are presented without error bars, without ablation on numerical precision across tile sizes, and without explicit verification that the chosen fusion and max-subtraction strategy preserves stability under the reported scales.
minor comments (2)
- Clarify the precise scope of the squared-Euclidean assumption and whether the kernel design extends to other cost functions.
- Add a short table or figure panel reporting marginal violation and OT-cost error versus a reference implementation across the range of n and m tested.
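The accuracy table the referee requests needs two numbers per problem size. Below is a minimal sketch with two assumptions. Marginal violation is measured as the L1 error of both marginals, as the rebuttal below proposes. Cost error is the relative gap in the linear transport cost $\langle C, P\rangle$. The function names are hypothetical. In practice, `P_ref` would come from a trusted solver such as POT's `ot.sinkhorn` run with the same $\varepsilon$.

```python
import numpy as np


def marginal_violation_l1(P, a, b):
    """||P 1 - a||_1 + ||P^T 1 - b||_1: L1 error of both marginals."""
    return float(np.abs(P.sum(axis=1) - a).sum() + np.abs(P.sum(axis=0) - b).sum())


def transport_cost(P, C):
    """Linear transport cost <C, P>; the entropic term is left out."""
    return float((C * P).sum())


def accuracy_row(P, P_ref, C, a, b):
    """One row of a hypothetical accuracy table: solver plan P vs. a reference plan."""
    ref = transport_cost(P_ref, C)
    return {
        "marginal_l1": marginal_violation_l1(P, a, b),
        "cost_rel_err": abs(transport_cost(P, C) - ref) / abs(ref),
    }


# Toy usage: the independent coupling a b^T has exact marginals; scaling it by 1.01 does not
rng = np.random.default_rng(0)
a, b = np.full(20, 1 / 20), np.full(30, 1 / 30)
C = rng.random((20, 30))
P_ref = np.outer(a, b)
print(accuracy_row(P_ref, P_ref, C, a, b))
print(accuracy_row(1.01 * P_ref, P_ref, C, a, b))  # marginal_l1 ≈ 0.02, cost_rel_err ≈ 0.01
```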
Simulated Author's Rebuttal
We thank the referee for the constructive feedback highlighting the importance of numerical validation. We agree that stronger evidence of equivalence and stability is needed to support the speedup claims and will revise the manuscript accordingly.
Point-by-point responses
- Referee: [§3] The fused Triton kernels are stated to implement the stabilized log-domain Sinkhorn recurrence via per-tile LogSumExp on biased dot products, but the manuscript provides neither a proof of equivalence nor quantitative accuracy tables (e.g., marginal violation or OT-cost error versus the POT baseline) at the largest reported problem sizes. This is load-bearing for the speedup claims, because any drift in the fixed-point dual potentials would render the performance numbers incomparable to the baselines.
Authors: We acknowledge that the current manuscript lacks an explicit proof and quantitative accuracy tables at the largest scales. The per-tile LogSumExp with max-subtraction and bias is mathematically equivalent to the standard stabilized log-domain Sinkhorn update. Each tile computes a partial maximum and a partial sum of exponentials over the same biased scores. Merging these partials with a running maximum, and rescaling the accumulated sum whenever the maximum increases, reproduces the full row-wise LogSumExp exactly, up to floating-point rounding. In the revised version we will add a short proof of equivalence in §3. We will also add a new table to the experiments reporting marginal violation (L1 error on marginals) and OT-cost error versus the POT baseline for all reported problem sizes, up to the largest n. These additions will confirm that the dual potentials match within floating-point tolerance and that the reported speedups are directly comparable. revision: yes
- Referee: [Abstract] Abstract and experimental section: Speedup figures are presented without error bars, without ablation on numerical precision across tile sizes, and without explicit verification that the chosen fusion and max-subtraction strategy preserves stability under the reported scales.
Authors: We agree that error bars, precision ablations, and explicit stability checks are missing. We will update the abstract and experimental sections to report mean speedups with standard deviation over five independent runs. We will also add an ablation subsection varying tile sizes (e.g., 32×32 to 128×128) and precision (FP32 vs. FP16), together with convergence plots showing that the max-subtraction strategy reaches the same fixed-point dual potentials as the reference implementation at all reported scales. These changes will directly address concerns about numerical reliability. revision: yes
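The proposed ablation over tile sizes and precisions can be prototyped before touching the kernels. This NumPy sketch uses hypothetical names and is not the authors' benchmark. It streams a row-wise LogSumExp over column tiles of width 32, 64 and 128, holding scores and accumulators in float32 or float16. It then reports the maximum absolute error against a float64 single-tile reference. Holding everything in float16 is a deliberately pessimistic setting, since attention-style kernels commonly keep the running max and sum in float32 even when their inputs are float16.

```python
import numpy as np


def tiled_lse(S, block, dtype):
    """Row-wise LogSumExp of S streamed over column tiles of width `block`,
    with scores, running max and running sum all held in `dtype`."""
    S = S.astype(dtype)
    run_max = np.full(S.shape[0], -np.inf, dtype=dtype)
    run_sum = np.zeros(S.shape[0], dtype=dtype)
    for start in range(0, S.shape[1], block):
        T = S[:, start:start + block]
        new_max = np.maximum(run_max, T.max(axis=1))
        # online merge: rescale the old sum to the new max, then add this tile
        run_sum = run_sum * np.exp(run_max - new_max) + np.exp(T - new_max[:, None]).sum(axis=1)
        run_max = new_max
    return run_max + np.log(run_sum)


def ablation(S, blocks=(32, 64, 128), dtypes=(np.float32, np.float16)):
    """Max absolute error of each (dtype, tile width) against a float64 single-tile reference."""
    ref = tiled_lse(S, S.shape[1], np.float64)
    return {
        (np.dtype(dt).name, blk): float(np.max(np.abs(tiled_lse(S, blk, dt).astype(np.float64) - ref)))
        for dt in dtypes
        for blk in blocks
    }


rng = np.random.default_rng(0)
X, Y = rng.random((256, 8)), rng.random((1024, 8))
S = (2.0 / 0.1) * (X @ Y.T) + rng.normal(size=1024)  # biased dot-product scores, eps = 0.1
for key, err in ablation(S).items():
    print(key, f"{err:.2e}")
```

In float32, the tile width should change the result only at rounding level, because the running-max merge is exact in real arithmetic. In float16, the quantization of the scores themselves is expected to dominate the error.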
Circularity Check
No circularity: empirical implementation with external validation
full rationale
The paper describes an engineering contribution that rewrites existing stabilized Sinkhorn iterations into a FlashAttention-compatible form to enable kernel fusion and reduced HBM traffic. All performance numbers are measured speedups against independent external baselines (POT, online Sinkhorn implementations) on concrete hardware. No first-principles derivation, fitted parameter, or self-citation chain is invoked to produce the claimed results; the mathematical equivalence between log-domain Sinkhorn and row-wise LogSumExp on biased dot products is a standard algebraic identity, not a reduction to the paper's own outputs. The work is therefore self-contained against external benchmarks.
Axiom & Free-Parameter Ledger
axioms (1)
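- [standard math] Stabilized log-domain Sinkhorn iterations converge to the same fixed point as the original multiplicative updates for any finite cost matrix. In exact arithmetic the two are iterate-by-iterate reformulations of each other, since the Gibbs kernel $\exp(-C/\varepsilon)$ is then strictly positive.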
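This axiom can be checked numerically on a small problem. The sketch below uses hypothetical function names and an assumed convention, $u_i = a_i e^{f_i/\varepsilon}$ and $v_j = b_j e^{g_j/\varepsilon}$, that links the scalings to the potentials. It runs the multiplicative updates $u = a / (Kv)$ and $v = b / (K^\top u)$, with $K = \exp(-C/\varepsilon)$, alongside their log-domain counterparts. The two sets of iterates coincide when $\varepsilon$ is large enough for $K$ to be representable.

```python
import numpy as np


def lse(S, axis):
    # max-subtracted LogSumExp along `axis`
    m = S.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(S - m).sum(axis=axis))


def sinkhorn_mult(C, a, b, eps, n_iter):
    """Classic scaling iterations on K = exp(-C/eps); returns potentials f, g."""
    K = np.exp(-C / eps)
    v = b.copy()  # v = b * exp(0 / eps), i.e. g = 0
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return eps * np.log(u / a), eps * np.log(v / b)  # map scalings to potentials


def sinkhorn_log(C, a, b, eps, n_iter):
    """The same iterations written on the potentials, stabilized by LogSumExp."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = -eps * lse((g[None, :] - C) / eps + np.log(b)[None, :], axis=1)
        g = -eps * lse((f[:, None] - C) / eps + np.log(a)[:, None], axis=0)
    return f, g


rng = np.random.default_rng(0)
X, Y = rng.random((30, 2)), rng.random((40, 2))
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)  # squared-Euclidean cost, entries <= 2
a, b = np.full(30, 1 / 30), np.full(40, 1 / 40)
f1, g1 = sinkhorn_mult(C, a, b, eps=0.5, n_iter=50)
f2, g2 = sinkhorn_log(C, a, b, eps=0.5, n_iter=50)
print(np.max(np.abs(f1 - f2)), np.max(np.abs(g1 - g2)))  # rounding-level differences
```

For small $\varepsilon$, the entries of $K$ can underflow to zero, and the multiplicative version can then divide by zero. This is why stabilized solvers work in the log domain.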
Forward citations
Cited by 2 Pith papers
- ASAP: Amortized Doubly-Stochastic Attention via Sliced Dual Projection
ASAP amortizes Sinkhorn-based doubly-stochastic attention by learning a parametric map from 1D potentials to the Sinkhorn dual and reconstructing the plan via two-sided entropic c-transform, delivering 5.3x faster inf...
- Spherical Harmonic Optimal Transport: Application to Climate Models Comparisons
Heat kernel Sinkhorn algorithm on the 2-sphere converges to OT cost with O(n) memory and O(n^{3/2}) time per iteration, retaining geometric properties and applied to climate model evaluation.
Reference graph
Works this paper leans on
- [1] Abubakar Abid and James Zou. Stochastic EM for shuffled linear regression. arXiv preprint arXiv:1804.00681, 2018.
- [2] Jason Altschuler, Jonathan Niles-Weed, and Philippe Rigollet. Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration. Advances in Neural Information Processing Systems, 30, 2017.
- [3] David Alvarez-Melis and Nicolo Fusi. Geometric dataset distances via optimal transport. Advances in Neural Information Processing Systems, 33:21428–21439, 2020.
- [4] David Alvarez-Melis and Nicolò Fusi. Dataset dynamics via gradient flows in probability space. In International Conference on Machine Learning, pages 219–230. PMLR, 2021.
- [5] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International Conference on Machine Learning, pages 214–223, 2017.
- [6] Austin R Benson, Jason D Lee, Bartek Rajwa, and David F Gleich. Scalable methods for nonnegative matrix factorizations of near-separable tall-and-skinny matrices. Advances in Neural Information Processing Systems, 27, 2014.
- [7] Dimitri P Bertsekas. Nonlinear programming. Journal of the Operational Research Society, 48(3):334–334, 1997.
- [8] Christoph Brauer, Christian Clason, Dirk Lorenz, and Benedikt Wirth. A Sinkhorn-Newton method for entropic optimal transport. arXiv preprint arXiv:1710.06635, 2017.
- [9] Charlotte Bunne, Geoffrey Schiebinger, Andreas Krause, Aviv Regev, and Marco Cuturi. Optimal transport for single-cell and spatial omics. Nature Reviews Methods Primers, 4(1):58, 2024.
- [10] Benjamin Charlier, Jean Feydy, Joan Alexis Glaunes, François-David Collin, and Ghislain Durif. Kernel operations on the GPU, with autodiff, without memory overflows. Journal of Machine Learning Research, 22(74):1–6, 2021.
- [11] Marco Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, volume 26, 2013.
- [12] Marco Cuturi and Arnaud Doucet. Fast computation of Wasserstein barycenters. In International Conference on Machine Learning, pages 685–693, 2014.
- [13] Marco Cuturi, Laetitia Meng-Papaxanthos, Yingtao Tian, Charlotte Bunne, Geoff Davis, and Olivier Teboul. Optimal transport tools (OTT): A JAX toolbox for all things Wasserstein. arXiv preprint arXiv:2201.12324, 2022.
- [14] Hadi Daneshmand, Jonas Kohler, Aurelien Lucchi, and Thomas Hofmann. Escaping saddles with stochastic gradients. In International Conference on Machine Learning, pages 1155–1164. PMLR, 2018.
- [15] Tri Dao. FlashAttention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
- [16] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing Systems, volume 35, pages 16344–16359, 2022.
- [17] Li Deng. The MNIST database of handwritten digit images for machine learning research [best of the web]. IEEE Signal Processing Magazine, 29(6):141–142, 2012.
- [18] Marvin Eisenberger, Aysim Toker, Laura Leal-Taixé, and Daniel Cremers. A unified framework for implicit Sinkhorn differentiation. In IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 509–518, 2022.
- [19] Jean Feydy. Geometric loss functions between sampled measures, images and volumes. URL https://www.kernel-operations.io/geomloss, 2019.
- [20] Jean Feydy. Geometric data analysis, beyond convolutions. Applied Mathematics, 3, 2020.
- [21] Jean Feydy, Thibault Séjourné, François-Xavier Vialard, Shun-ichi Amari, Alain Trouve, and Gabriel Peyré. Interpolating between optimal transport and MMD using Sinkhorn divergences. In International Conference on Artificial Intelligence and Statistics, pages 2681–2690, 2019.
- [22] Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, et al. POT: Python optimal transport. Journal of Machine Learning Research, 22(78):1–8, 2021.
- [23] Rong Ge, Furong Huang, Chi Jin, and Yang Yuan. Escaping from saddle points—online stochastic gradient for tensor decomposition. In Conference on Learning Theory, pages 797–
- [24] Gene H Golub and Charles F Van Loan. Matrix computations. JHU Press, 2013.
- [25] Chi Jin, Rong Ge, Praneeth Netrapalli, Sham M Kakade, and Michael I Jordan. How to escape saddle points efficiently. In International Conference on Machine Learning, pages 1724–1732. PMLR, 2017.
- [26] Dmitry Kangin and Plamen Angelov. Unsupervised domain adaptation within deep foundation latent spaces. arXiv preprint arXiv:2402.14976, 2024.
- [27] Mete Kemertas, Amir-massoud Farahmand, and Allan Douglas Jepson. A truncated Newton method for optimal transport. In The Thirteenth International Conference on Learning Representations, 2025.
- [28] Alex Kokot and Alex Luedtke. Coreset selection for the Sinkhorn divergence and generic smooth divergences. arXiv preprint arXiv:2504.20194, 2025.
- [29] Haonan Li, Keyu Man, Partha Kanuparthy, Hanning Chen, Wei Sun, Sreen Tallam, Chenguang Zhu, Kevin Zhu, and Zhiyun Qian. Tritonforge: Profiling-guided framework for automated Triton kernel optimization. arXiv preprint arXiv:2512.09196, 2025.
- [30] Xingjie Li, Fei Lu, Molei Tao, and Felix X-F Ye. Robust first- and second-order differentiation for regularized optimal transport. SIAM Journal on Scientific Computing, 47(3):C630–C654, 2025.
- [31] Jacob Lindbäck, Zesen Wang, and Mikael Johansson. Bringing regularized optimal transport to lightspeed: a splitting method adapted for GPUs. Advances in Neural Information Processing Systems, 36:26845–26871, 2023.
- [32] Elon Litman. Scaled-dot-product attention as one-sided entropic optimal transport. arXiv preprint arXiv:2508.08369, 2025.
- [33] Elon Litman and Gabe Guo. You need better attention priors. arXiv preprint arXiv:2601.15380, 2026.
- [34] James Martens et al. Deep learning via Hessian-free optimization. In ICML, volume 27, pages 735–742, 2010.
- [35] Gonzalo Mena, David Belanger, Scott Linderman, and Jasper Snoek. Learning latent permutations with Gumbel-Sinkhorn networks. In International Conference on Learning Representations, 2018.
- [36] Arthur Mensch and Gabriel Peyré. Online Sinkhorn: Optimal transport distances from sample streams. Advances in Neural Information Processing Systems, 33:1657–1667, 2020.
- [37] Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. arXiv preprint arXiv:1805.02867, 2018.
- [38] Gabriel Peyré and Marco Cuturi. Computational optimal transport. Foundations and Trends in Machine Learning, 11(5-6):355–607, 2019.
- [39] Aram-Alexandre Pooladian and Jonathan Niles-Weed. Entropic estimation of optimal transport maps. arXiv preprint arXiv:2109.12004, 2021.
- [40] Burkhard Ringlein, Jan van Lunteren, Radu Stoica, and Thomas Parnell. The anatomy of a Triton attention kernel. arXiv preprint arXiv:2511.11581, 2025.
- [41] Michael E Sander, Pierre Ablin, Mathieu Blondel, and Gabriel Peyré. Sinkformers: Transformers with doubly stochastic attention. In International Conference on Artificial Intelligence and Statistics, pages 3515–3530. PMLR, 2022.
- [42] Geoffrey Schiebinger, Jian Shu, Marcin Tabaka, Brian Cleary, Vidya Subramanian, Aryeh Solomon, Joshua Gould, Siyan Liu, Stacie Lin, Peter Berber, et al. Optimal-transport analysis of single-cell gene expression identifies developmental trajectories in reprogramming. Cell, 176(4):928–943, 2019.
- [43] Bernhard Schmitzer. Stabilized sparse scaling algorithms for entropy regularized transport problems. SIAM Journal on Scientific Computing, 41(3):A1443–A1481, 2019.
- [44] Zebang Shen, Zhenfu Wang, Alejandro Ribeiro, and Hamed Hassani. Sinkhorn natural gradient for generative models. Advances in Neural Information Processing Systems, 33:1646–1656, 2020.
- [45] Justin Solomon, Fernando De Goes, Gabriel Peyré, Marco Cuturi, Adrian Butscher, Andy Nguyen, Tao Du, and Leonidas Guibas. Convolutional Wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics, 34(4):1–11, 2015.
- [46] Zihao Tang and Yixuan Qiu. Safe and sparse Newton method for entropic-regularized optimal transport. Advances in Neural Information Processing Systems, 37:129914–129943, 2024.
- [47] Yi Tay, Dara Bahri, Liu Yang, Donald Metzler, and Da-Cheng Juan. Sparse Sinkhorn attention. In International Conference on Machine Learning, pages 9438–9447. PMLR, 2020.
- [48] Stephen Wright, Jorge Nocedal, et al. Numerical optimization. Springer Science, 35(67-68):7, 1999.
- [49] Han Xiao, Kashif Rasul, and Roland Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.
- [50] Yujia Xie, Yixiu Mao, Simiao Zuo, Hongteng Xu, Xiaojing Ye, Tuo Zhao, and Hongyuan Zha. A hypergradient approach to robust regression without correspondence. arXiv preprint arXiv:2012.00123, 2020.
- [51] Naoya Yamamoto, Juno Kim, and Taiji Suzuki. Hessian-guided perturbed Wasserstein gradient flows for escaping saddle points. arXiv preprint arXiv:2509.16974, 2025.
  A Notation. Table 4: Notation used throughout the paper. Discrete measures and cost: $X=[x_1;\dots;x_n]\in\mathbb{R}^{n\times d}$, source points (rows), $x_i\in\mathbb{R}^d$; $Y=[y_1;\dots;y_m]\in\mathbb{R}^{m\times d}$, target points...
- [52] Points are sampled uniformly from $[0,1]^d$ with $n\in[5\mathrm{k},50\mathrm{k}]$, $d\in[4,1024]$ for forward/backward benchmarks and $d\in[4,512]$ for HVP examples. All methods use regularization $\varepsilon=0.1$ and 10 Sinkhorn iterations for forward/backward benchmarks, or 100 Sinkhorn iterations for HVP benchmarks. For HVP, the conjugate gradient solver uses $K=50$ fixed iterations with damping $\tau=$...
- [53] ...is not benchmarked here since our evaluation targets GPU implementations. GeomLoss also provides backend='multiscale', a fast octree-style multiscale routine intended for very large ... Table 5: NCU profiling of the forward pass ($n=m=10000$, $d=64$, 10 Sinkhorn iterations, A100-80GB). Note: Working set (5 MB) fits in A100 L2 cache (40 MB); HBM traffic reflects L...
- [54] Ground truth transformation: $W^\ast\in\mathbb{R}^{5\times 5}$ with entries $W^\ast_{ij}\sim\mathcal{N}(0,1/5)$.
- [55] Clean targets: $Y_{\mathrm{clean}} = XW^\ast$.
- [56] Noisy targets: $Y = Y_{\mathrm{clean}} + E$, where $E_{ij}\sim\mathcal{N}(0,\sigma^2)$ with $\sigma = 0.05\cdot\mathrm{std}(Y_{\mathrm{clean}})$.
- [57] Shuffled observations: $\tilde Y = \Pi^\ast(Y)$ for an unknown permutation $\Pi^\ast$. The optimization objective is $\min_W L(W)$, where
  $$L(W) = \mathrm{OT}_\varepsilon\Big(\frac{1}{n}\sum_{i=1}^{n}\delta_{y_i},\ \frac{1}{n}\sum_{j=1}^{n}\delta_{\tilde y_j}\Big) = \min_{P\in\Pi(\frac{1}{n}\mathbf{1}_n,\,\frac{1}{n}\mathbf{1}_n)} \langle C(W), P\rangle + \varepsilon\,\mathrm{KL}\Big(P \,\Big\|\, \frac{1}{n}\mathbf{1}_n \otimes \frac{1}{n}\mathbf{1}_n\Big),$$
  with $y_i = x_i W$ and $C_{ij}(W) = \|y_i - \tilde y_j\|_2^2$. Optimizer configuration: we test three regularization strengths $\varepsilon\in\{0.1, 0.25, 0.5\}$. The solver uses e...