Deep Kuratowski Embedding Neural Networks for Wasserstein Metric Learning
Pith reviewed 2026-05-10 20:18 UTC · model grok-4.3
The pith
Neural architectures inspired by the Kuratowski embedding theorem can learn accurate approximations to the Wasserstein-2 distance.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
A Kuratowski-style embedding of the Wasserstein metric, realized either through weighted CNN feature distances (DeepKENN) or through Neural ODE trajectories in C^1([0,1], R^d) (ODE-KENN), yields learned surrogate distances for W_2. On held-out MNIST pairs, ODE-KENN's test MSE is 28 percent lower than that of a single-layer baseline and 18 percent lower than that of the discrete-layer DeepKENN, and the continuous embedding also narrows the train-test gap.
What carries the argument
Kuratowski embedding of the Wasserstein metric, realized either by learnable aggregation of CNN feature distances or by continuous Neural ODE trajectories that map inputs into a space of smooth functions.
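For reference, the classical construction can be written out as follows. This is the standard textbook statement for a general metric space $(X, d)$ with an arbitrary base point $x_0$; the symbols $\Phi$ and $x_0$ are notation chosen here, and how the paper specializes the construction to the Wasserstein space is not shown in this summary.

```latex
\Phi : X \to C_b(X), \qquad \Phi(x)(z) = d(x, z) - d(x_0, z),
\qquad
\|\Phi(x) - \Phi(y)\|_\infty = \sup_{z \in X} \bigl| d(x, z) - d(y, z) \bigr| = d(x, y).
```

The upper bound follows from the triangle inequality, and the supremum is attained at $z = y$, so the map is an exact isometry. A learned network can only approximate this property.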
If this is right
- The trained networks can serve as drop-in replacements for exact W2 oracles inside pairwise distance computations.
- ODE-KENN yields both lower error and smaller generalization gap than discrete-layer models at matched parameter count.
- Trajectory smoothness supplies implicit regularization that improves out-of-sample performance.
- The resulting fast surrogate enables scaling of Wasserstein-based analyses to larger datasets.
Where Pith is reading between the lines
- The same embedding strategy might be tested on other optimal-transport costs or on non-image data modalities.
- Extending the ODE embedding to variable-length trajectories could connect to sequence or graph metric learning.
- Direct comparison against existing fast approximations such as Sinkhorn iterations would clarify relative speed-accuracy trade-offs.
- The continuous-function-space view may suggest regularizers for other metric-approximation tasks beyond Wasserstein distance.
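The Sinkhorn comparison suggested in the list above could start from a baseline like the following. It is a minimal pure-Python sketch of entropic optimal transport on toy histograms, not code from the paper; the function names, the regularization `eps`, and the example data are all assumptions made for illustration.

```python
import math

def sinkhorn_plan(a, b, cost, eps=1.0, n_iters=500):
    """Entropic OT plan between histograms a and b (lists summing to 1)."""
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-C_ij / eps)
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        # Alternate scalings: rows are matched to a, then columns to b
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

def transport_cost(plan, cost):
    return sum(p * c for prow, crow in zip(plan, cost) for p, c in zip(prow, crow))

# Toy example (made-up data): histograms on the points 0, 1, 2, squared-distance cost
points = [0.0, 1.0, 2.0]
cost = [[(x - y) ** 2 for y in points] for x in points]
a = [0.5, 0.3, 0.2]
b = [0.2, 0.3, 0.5]
plan = sinkhorn_plan(a, b, cost)
approx_w2 = math.sqrt(transport_cost(plan, cost))
```

The value `approx_w2` is an entropically smoothed estimate that sits above the exact W_2. Smaller `eps` tightens the estimate but slows convergence, which is the speed-accuracy trade-off a comparison against the learned surrogate would need to chart.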
Load-bearing premise
The neural embeddings preserve enough of the Wasserstein geometry to generalize from the MNIST training distribution to new samples.
What would settle it
Retraining both architectures on a shifted distribution such as CIFAR-10 and finding that test MSE on held-out pairs no longer improves over the single-layer baseline would falsify the claim of a useful generalizable embedding.
Original abstract
Computing pairwise Wasserstein distances is a fundamental bottleneck in data analysis pipelines. Motivated by the classical Kuratowski embedding theorem, we propose two neural architectures for learning to approximate the Wasserstein-2 distance ($W_2$) from data. The first, DeepKENN, aggregates distances across all intermediate feature maps of a CNN using learnable positive weights. The second, ODE-KENN, replaces the discrete layer stack with a Neural ODE, embedding each input into the infinite-dimensional Banach space $C^1([0,1], \mathbb{R}^d)$ and providing implicit regularization via trajectory smoothness. Experiments on MNIST with exact precomputed $W_2$ distances show that ODE-KENN achieves a 28% lower test MSE than the single-layer baseline and 18% lower than DeepKENN under matched parameter counts, while exhibiting a smaller generalization gap. The resulting fast surrogate can replace the expensive $W_2$ oracle in downstream pairwise distance computations.
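The DeepKENN aggregation described in the abstract might look roughly like the sketch below. It assumes a Euclidean norm per layer and a softplus parameterization to keep the weights positive; neither choice is confirmed by the abstract, and the feature vectors are made-up stand-ins for flattened CNN feature maps.

```python
import math

def softplus(a):
    # Smooth map from any real number to a strictly positive weight
    return math.log1p(math.exp(a))

def l2(u, v):
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(u, v)))

def aggregated_distance(feats_x, feats_y, raw_weights):
    """Weighted sum of per-layer feature distances.

    feats_x, feats_y: one flattened feature vector per layer (hypothetical).
    raw_weights: unconstrained parameters, one per layer.
    """
    return sum(
        softplus(a) * l2(fx, fy)
        for a, fx, fy in zip(raw_weights, feats_x, feats_y)
    )

# Toy features for two inputs across three layers (made-up numbers)
feats_x = [[0.0, 1.0], [2.0, 0.0, 1.0], [0.5]]
feats_y = [[1.0, 1.0], [2.0, 2.0, 1.0], [0.0]]
raw_weights = [0.0, -1.0, 2.0]

d_xy = aggregated_distance(feats_x, feats_y, raw_weights)
d_yx = aggregated_distance(feats_y, feats_x, raw_weights)
d_xx = aggregated_distance(feats_x, feats_x, raw_weights)
```

Because each term is a positive multiple of a norm of a feature difference, the aggregate is symmetric, vanishes on identical inputs, and satisfies the triangle inequality regardless of the weight values.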
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The paper proposes two neural architectures, DeepKENN and ODE-KENN, motivated by the Kuratowski embedding theorem, to learn approximations to the Wasserstein-2 distance. DeepKENN aggregates distances across CNN feature maps using learnable positive weights, while ODE-KENN replaces the discrete layers with a Neural ODE to embed inputs into the function space C^1([0,1], R^d) with implicit regularization from trajectory smoothness. On MNIST using precomputed exact W_2 targets, the manuscript reports that ODE-KENN achieves 28% lower test MSE than a single-layer baseline and 18% lower than DeepKENN under matched parameter counts, along with a smaller generalization gap. The resulting models are positioned as fast surrogates for expensive W_2 oracles in downstream pairwise distance tasks.
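To make the ODE-KENN idea in the summary concrete, the sketch below integrates a small hand-written vector field with forward Euler and compares two trajectories in a discretized C^1 norm (sup norm of the path plus sup norm of its derivative). The dynamics, the step count, and the choice of norm are assumptions for illustration; the paper's actual encoder, solver, and norm are not specified in this report.

```python
import math

def vector_field(z, W, b):
    # Hypothetical dynamics dz/dt = tanh(W z + b); stands in for a learned network
    return [math.tanh(sum(w * zj for w, zj in zip(row, z)) + bi)
            for row, bi in zip(W, b)]

def trajectory(z0, W, b, n_steps=100):
    """Forward-Euler samples of z(t) and z'(t) on a uniform grid over [0, 1]."""
    h = 1.0 / n_steps
    zs, dzs = [], []
    z = list(z0)
    for _ in range(n_steps + 1):
        dz = vector_field(z, W, b)
        zs.append(z)
        dzs.append(dz)
        z = [zi + h * di for zi, di in zip(z, dz)]
    return zs, dzs

def sup_norm_diff(path_a, path_b):
    # Largest Euclidean gap between two sampled paths over the time grid
    return max(math.dist(p, q) for p, q in zip(path_a, path_b))

def c1_distance(z0_x, z0_y, W, b):
    zs_x, dzs_x = trajectory(z0_x, W, b)
    zs_y, dzs_y = trajectory(z0_y, W, b)
    return sup_norm_diff(zs_x, zs_y) + sup_norm_diff(dzs_x, dzs_y)

# Made-up parameters and initial states
W = [[0.0, -1.0], [1.0, 0.0]]
b = [0.1, -0.2]
d_xy = c1_distance([1.0, 0.0], [0.0, 1.0], W, b)
d_yx = c1_distance([0.0, 1.0], [1.0, 0.0], W, b)
d_xx = c1_distance([1.0, 0.0], [1.0, 0.0], W, b)
```

Since the sup is taken over the whole time grid, the distance is never smaller than the gap between the initial states, and the derivative term is where trajectory smoothness enters the surrogate.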
Significance. If the reported MSE reductions hold under scrutiny and the architectures provide embeddings that meaningfully approximate the Wasserstein metric, the work supplies a practical tool for bypassing the computational cost of W_2 calculations in metric learning and optimal transport pipelines. The controlled comparison with matched parameter counts and the introduction of Neural ODE trajectories for continuous embeddings are concrete strengths that could influence scalable implementations of Wasserstein-based methods. The empirical focus on MNIST with exact targets offers a clear testbed, though broader significance hinges on generalization and verification of the embedding properties.
major comments (2)
- [§3.2] §3.2 (ODE-KENN construction): the claim that the Neural ODE trajectory realizes a Kuratowski-style embedding into C^1([0,1], R^d) that approximates W_2 is load-bearing for the paper's motivation, yet the manuscript provides no analysis or numerical check that the learned map satisfies the necessary isometry or distance-preservation properties of the classical theorem; without this, the performance gains could be explained by standard regression rather than the embedding framework.
- [§4] §4 (Experiments and results): the headline 28% and 18% test-MSE reductions are presented without accompanying details on the precise loss formulation, optimization schedule, data splits for the precomputed W_2 targets, or statistical significance across multiple random seeds; this information is required to assess whether the smaller generalization gap is robust or sensitive to unreported implementation choices.
minor comments (2)
- [Abstract] The abstract and §4 should explicitly state the solver or algorithm used to precompute the exact W_2 targets on MNIST, as this affects reproducibility of the regression targets.
- [§3.1] Notation for the positive weights in the DeepKENN aggregation step could be formalized with an equation that shows the positivity constraint and how they are initialized or regularized.
Simulated Author's Rebuttal
We thank the referee for the constructive and detailed comments, which have helped us identify areas for clarification and improvement. We address each major comment point by point below and will revise the manuscript to incorporate additional analysis and experimental details as outlined.
- Referee: [§3.2] §3.2 (ODE-KENN construction): the claim that the Neural ODE trajectory realizes a Kuratowski-style embedding into C^1([0,1], R^d) that approximates W_2 is load-bearing for the paper's motivation, yet the manuscript provides no analysis or numerical check that the learned map satisfies the necessary isometry or distance-preservation properties of the classical theorem; without this, the performance gains could be explained by standard regression rather than the embedding framework.
  Authors: We agree that explicit verification of distance-preservation properties would strengthen the link to the Kuratowski embedding theorem. The ODE-KENN architecture is explicitly motivated by the theorem's construction of an isometric embedding into a Banach space of continuous functions, with the Neural ODE providing a continuous trajectory in C^1([0,1], R^d) and implicit regularization from smoothness. We do not claim that the learned map achieves exact isometry, as the network is trained to approximate W_2 via regression; rather, the embedding framework supplies the inductive bias for the architecture. To directly address the concern, the revised manuscript will add a new subsection in §3.2 (or §4) reporting numerical checks on a held-out set: specifically, the Pearson correlation between pairwise distances computed from the embedded trajectories (using the appropriate norm on C^1) and the ground-truth W_2 values, as well as a comparison of this correlation against a generic MLP regressor with matched capacity. These checks will demonstrate that the learned embedding preserves distances beyond what would be expected from unstructured regression, consistent with the smaller generalization gap observed under matched parameter counts.
  revision: partial
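The correlation check proposed in this response could be computed along the following lines. This is a minimal sketch with placeholder numbers, not results from the paper; `d_model_a` and `d_model_b` are hypothetical names for the distances predicted by two competing models on the same held-out pairs.

```python
import math

def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mx = sum(xs) / n
    my = sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

# Placeholder values for illustration only
w2_true = [1.0, 2.0, 3.0, 4.0, 5.0]
d_model_a = [1.1, 1.9, 3.2, 3.9, 5.1]
d_model_b = [2.0, 1.0, 4.0, 3.0, 5.0]

r_a = pearson(d_model_a, w2_true)
r_b = pearson(d_model_b, w2_true)
```

Pearson correlation is invariant to affine rescaling of the predictions, so on its own it cannot distinguish a near-isometry from a map that distorts all distances by a constant factor. Reporting the spread of the ratio between predicted and true distances alongside it would address the isometry question more directly.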
- Referee: [§4] §4 (Experiments and results): the headline 28% and 18% test-MSE reductions are presented without accompanying details on the precise loss formulation, optimization schedule, data splits for the precomputed W_2 targets, or statistical significance across multiple random seeds; this information is required to assess whether the smaller generalization gap is robust or sensitive to unreported implementation choices.
  Authors: We acknowledge that these details are necessary for reproducibility and for evaluating the robustness of the reported improvements. In the revised manuscript, Section 4 will be expanded to include: the precise loss (mean squared error between predicted and precomputed W_2 targets), the full optimization schedule (optimizer, learning rate schedule, number of epochs, and any regularization), the data split protocol for generating and partitioning the precomputed W_2 pairs (including train/validation/test ratios and how pairs were sampled), and results aggregated over multiple independent random seeds (with means and standard deviations) to establish statistical significance of the 28% and 18% MSE reductions as well as the generalization gap.
  revision: yes
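The per-seed aggregation promised here can be sketched as below. The per-seed MSE values are placeholders, not numbers from the paper, and the helper names are hypothetical. The last line is plain arithmetic on the two headline percentages: if ODE-KENN is 28% below the baseline and 18% below DeepKENN on the same test set, DeepKENN itself sits roughly 12% below the baseline.

```python
import statistics

def relative_reduction(baseline_mse, model_mse):
    # Fraction by which model_mse undercuts baseline_mse (0.28 means 28% lower)
    return (baseline_mse - model_mse) / baseline_mse

def summarize(per_seed_mse):
    # Mean and sample standard deviation across independent seeds
    return statistics.mean(per_seed_mse), statistics.stdev(per_seed_mse)

# Placeholder per-seed test MSEs, not values from the paper
baseline_runs = [1.00, 1.10, 0.90]
model_runs = [0.70, 0.80, 0.75]

base_mean, base_std = summarize(baseline_runs)
model_mean, model_std = summarize(model_runs)
reduction = relative_reduction(base_mean, model_mean)

# Implied DeepKENN-vs-baseline reduction from the reported 28% and 18% figures
implied_deepkenn_reduction = relative_reduction(1.0, 0.72 / 0.82)
```

Reporting the standard deviations next to the means shows whether a gap of the claimed size exceeds seed-to-seed variation.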
Circularity Check
No significant circularity in derivation chain
The paper motivates its architectures (DeepKENN and ODE-KENN) directly from the classical external Kuratowski embedding theorem and evaluates them via standard supervised regression on precomputed W2 targets from MNIST. No load-bearing step reduces a claimed prediction to a fitted input by construction, no self-citation chain supports the central premise, and the reported MSE improvements are ordinary empirical outcomes rather than tautological renamings or ansatzes smuggled via prior author work. The derivation remains self-contained against external benchmarks.
Axiom & Free-Parameter Ledger
free parameters (1)
- learnable positive weights for feature-map aggregation
axioms (1)
- standard math: Kuratowski embedding theorem applies to the Wasserstein metric space
A Architecture Details
Table 2: CNN encoder architecture shared by all three models.
Layer | Operation | Output shape | Flat dim
Conv1 | Conv(1→8, 5×5) + ReLU + MaxPool(2) | (8, 14, 14) | 1568
Conv2 | Conv(8→16, 3×3) + ReLU + MaxPool(2) | (16, 7, 7) | 784
Conv3 | Conv(16→32, 3×3) + ReLU + MaxPool(2) | (32, 3, 3) | 288
FC1 | L...