
arxiv: 2604.04343 · v1 · submitted 2026-04-06 · 💻 cs.LG

Deep Kuratowski Embedding Neural Networks for Wasserstein Metric Learning

Pith reviewed 2026-05-10 20:18 UTC · model grok-4.3

classification 💻 cs.LG
keywords Wasserstein distance · Kuratowski embedding · neural ODE · metric learning · optimal transport · deep learning · MNIST · distance approximation

The pith

Neural architectures inspired by the Kuratowski embedding theorem can learn accurate approximations to the Wasserstein-2 distance.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper introduces two neural networks to approximate the expensive Wasserstein-2 distance from data samples. DeepKENN aggregates distances from intermediate CNN feature maps using learned positive weights. ODE-KENN instead maps each input to a continuous trajectory in an infinite-dimensional function space via a neural ordinary differential equation, which adds smoothness-based regularization. On MNIST with exact precomputed targets, the ODE version delivers lower test error and a tighter generalization gap than both a single-layer model and the discrete DeepKENN under equal parameter budgets, opening the way for fast surrogate distances in larger pipelines.
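A minimal plain-Python sketch of the DeepKENN distance described above: a positively weighted sum of per-layer feature distances. The Euclidean per-layer distance, the softplus used to keep the weights positive, and all function names are assumptions for illustration, not the paper's code. The CNN itself is stubbed out by passing precomputed, flattened feature maps.

```python
import math

def softplus(a: float) -> float:
    # Numerically stable log(1 + exp(a)); strictly positive for any real a.
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))

def euclidean(u, v):
    # Assumed per-layer distance between two flattened feature maps.
    return math.sqrt(sum((ui - vi) ** 2 for ui, vi in zip(u, v)))

def deepkenn_distance(feats_x, feats_y, raw_weights):
    """Hypothetical DeepKENN-style surrogate: sum_l softplus(a_l) * ||phi_l(x) - phi_l(y)||.

    feats_x, feats_y: one flattened feature map per CNN layer for each input.
    raw_weights: one unconstrained learnable scalar per layer.
    """
    if not (len(feats_x) == len(feats_y) == len(raw_weights)):
        raise ValueError("need one feature map and one weight per layer")
    return sum(softplus(a) * euclidean(fx, fy)
               for a, fx, fy in zip(raw_weights, feats_x, feats_y))

# Two toy inputs, two "layers": per-layer distances 5.0 and 0.0, weights softplus(0) = log 2.
d = deepkenn_distance([[0.0, 0.0], [1.0]], [[3.0, 4.0], [1.0]], [0.0, 0.0])
print(round(d, 4))  # 3.4657
```

Keeping the raw parameters unconstrained and mapping them through softplus is one common way to enforce positivity under plain gradient descent; an exponential map would serve the same purpose.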

Core claim

By realizing a Kuratowski-style embedding of the Wasserstein metric, either through weighted CNN feature distances or through Neural ODE trajectories in C^1([0,1], R^d), the networks produce fast surrogate distances. At matched parameter counts, the ODE variant's squared error on held-out MNIST pairs is 28 percent lower than that of a single-layer baseline and 18 percent lower than that of the discrete-layer variant, and its continuous embedding also narrows the train-test gap.
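The two headline percentages also fix how DeepKENN compares with the single-layer baseline. Assuming both reductions are relative to test MSE measured on the same held-out pairs, MSE_ODE = 0.72 MSE_base and MSE_ODE = 0.82 MSE_DeepKENN, so MSE_DeepKENN / MSE_base = 0.72 / 0.82. A short check (the function name is ours):

```python
def implied_mse_ratio(reduction_vs_base: float, reduction_vs_deep: float) -> float:
    """MSE_deep / MSE_base implied by two relative reductions of the same MSE_ode.

    MSE_ode = (1 - reduction_vs_base) * MSE_base
    MSE_ode = (1 - reduction_vs_deep) * MSE_deep
    """
    return (1.0 - reduction_vs_base) / (1.0 - reduction_vs_deep)

ratio = implied_mse_ratio(0.28, 0.18)
print(f"{ratio:.3f}")      # 0.878
print(f"{1 - ratio:.3f}")  # 0.122
```

On that reading, DeepKENN is roughly 12 percent better than the baseline, so the step from DeepKENN to ODE-KENN accounts for the larger share of the reported improvement.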

What carries the argument

Kuratowski embedding of the Wasserstein metric, realized either by learnable aggregation of CNN feature distances or by continuous Neural ODE trajectories that map inputs into a space of smooth functions.
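For reference, the classical construction that the architectures imitate (the standard statement, not taken from the paper): any metric space (X, d) embeds isometrically into the Banach space of bounded functions on itself.

```latex
\text{Fix } x_0 \in X, \qquad \Phi : X \to \ell^\infty(X), \qquad \Phi(x)(z) = d(x,z) - d(x_0,z),

\lVert \Phi(x) - \Phi(y) \rVert_\infty = \sup_{z \in X} \lvert d(x,z) - d(y,z) \rvert = d(x,y).
```

The upper bound follows from the triangle inequality, equality is attained at z = y, and each Φ(x) is bounded because |d(x,z) − d(x0,z)| ≤ d(x,x0). The networks swap ℓ∞(X) for a learned feature space (DeepKENN) or for C^1([0,1], R^d) (ODE-KENN), so isometry is no longer guaranteed by construction and has to be learned.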

If this is right

  • The trained networks can serve as drop-in replacements for exact W2 oracles inside pairwise distance computations.
  • ODE-KENN yields both lower error and smaller generalization gap than discrete-layer models at matched parameter count.
  • Trajectory smoothness supplies implicit regularization that improves out-of-sample performance.
  • The resulting fast surrogate enables scaling of Wasserstein-based analyses to larger datasets.
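What "drop-in replacement" buys in a pairwise computation can be sketched as follows, assuming (as the description of per-input feature maps and trajectories suggests) that the surrogate distance is computed from one stored embedding per input. The callables embed and dist are hypothetical placeholders for a trained network and its distance head.

```python
from itertools import combinations

def pairwise_with_surrogate(samples, embed, dist):
    """Embed each sample once, then fill a symmetric matrix with cheap comparisons.

    Returns (distance_matrix, number_of_embed_calls). An exact W2 oracle would
    instead solve one optimal-transport problem per unordered pair.
    """
    reps = [embed(s) for s in samples]  # n network forward passes
    n = len(reps)
    D = [[0.0] * n for _ in range(n)]
    for i, j in combinations(range(n), 2):  # n(n-1)/2 cheap distance evaluations
        D[i][j] = D[j][i] = dist(reps[i], reps[j])
    return D, len(reps)

# Toy stand-ins: 1-D "embeddings" compared by absolute difference.
D, calls = pairwise_with_surrogate([0.0, 1.0, 3.0, 6.0], lambda x: [x], lambda a, b: abs(a[0] - b[0]))
n = 4
print(calls, n * (n - 1) // 2)  # 4 6
```

For n samples the network cost grows linearly, while the oracle cost grows with n(n-1)/2 OT solves; the remaining quadratic term is only the cheap comparison of stored embeddings.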

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • The same embedding strategy might be tested on other optimal-transport costs or on non-image data modalities.
  • Extending the ODE embedding to variable-length trajectories could connect to sequence or graph metric learning.
  • Direct comparison against existing fast approximations such as Sinkhorn iterations would clarify relative speed-accuracy trade-offs.
  • The continuous-function-space view may suggest regularizers for other metric-approximation tasks beyond Wasserstein distance.
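For a sense of the baseline such a comparison would use, here is a minimal plain-Python version of the Sinkhorn iterations (Cuturi, 2013) on a two-point example where the exact squared W2 is 1. It is an illustrative sketch without log-domain stabilization, so very small eps values can underflow; it is not the paper's code or any library's API.

```python
import math

def sinkhorn_cost(a, b, C, eps=0.5, n_iter=500):
    """Entropic OT between histograms a, b with cost matrix C (Sinkhorn iterations).

    Alternately rescales K = exp(-C / eps) so that P = diag(u) K diag(v) matches
    the marginals; returns (<P, C>, P).
    """
    K = [[math.exp(-c / eps) for c in row] for row in C]
    n, m = len(a), len(b)
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    cost = sum(P[i][j] * C[i][j] for i in range(n) for j in range(m))
    return cost, P

# Two uniform point clouds on the line, x = [0, 1] and y = [1, 2]; cost = squared distance.
a, b = [0.5, 0.5], [0.5, 0.5]
C = [[(xi - yj) ** 2 for yj in (1.0, 2.0)] for xi in (0.0, 1.0)]
cost, P = sinkhorn_cost(a, b, C, eps=0.5)
print(round(cost, 4))  # 1.1192 (entropic bias above the exact value 1.0)
```

Shrinking eps moves the result toward the exact value at the price of more iterations, which is the speed-accuracy trade-off such a comparison would have to quantify.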

Load-bearing premise

The neural embeddings preserve enough of the Wasserstein geometry to generalize from the MNIST training distribution to new samples.

What would settle it

Retraining both architectures on a shifted distribution such as CIFAR-10 and finding that test MSE on held-out pairs no longer improves over the single-layer baseline would falsify the claim of a useful generalizable embedding.

Figures

Figures reproduced from arXiv: 2604.04343 by Andrew Qing He.

Figure 1. Experimental results for all three models trained for 2,000 epochs on MNIST.
Original abstract

Computing pairwise Wasserstein distances is a fundamental bottleneck in data analysis pipelines. Motivated by the classical Kuratowski embedding theorem, we propose two neural architectures for learning to approximate the Wasserstein-2 distance ($W_2$) from data. The first, DeepKENN, aggregates distances across all intermediate feature maps of a CNN using learnable positive weights. The second, ODE-KENN, replaces the discrete layer stack with a Neural ODE, embedding each input into the infinite-dimensional Banach space $C^1([0,1], \mathbb{R}^d)$ and providing implicit regularization via trajectory smoothness. Experiments on MNIST with exact precomputed $W_2$ distances show that ODE-KENN achieves a 28% lower test MSE than the single-layer baseline and 18% lower than DeepKENN under matched parameter counts, while exhibiting a smaller generalization gap. The resulting fast surrogate can replace the expensive $W_2$ oracle in downstream pairwise distance computations.
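The ODE-KENN idea can be sketched without a deep learning framework: integrate a vector field from the encoded input, keep the trajectory and its time derivative on a grid, and compare two trajectories in a discretized C^1 norm, sup_t |z(t)| + sup_t |z'(t)|. The forward-Euler solver, the toy vector field, the identity encoder, and this particular C^1 norm are assumptions for illustration; the abstract does not specify them.

```python
import math

def trajectory(z0, field, steps=100):
    """Forward-Euler solve of dz/dt = field(z, t) on [0, 1].

    Returns states z(t_k) and derivatives z'(t_k) = field(z(t_k), t_k) on a uniform grid.
    """
    h = 1.0 / steps
    z = list(z0)
    states, derivs = [], []
    for k in range(steps + 1):
        dz = field(z, k * h)
        states.append(z)
        derivs.append(dz)
        z = [zi + h * di for zi, di in zip(z, dz)]  # Euler step (unused after t = 1)
    return states, derivs

def _dist(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def c1_distance(traj_x, traj_y):
    """Grid version of ||z_x - z_y||_{C^1} = sup_t |z_x - z_y| + sup_t |z_x' - z_y'|."""
    (sx, dx), (sy, dy) = traj_x, traj_y
    return (max(_dist(a, b) for a, b in zip(sx, sy))
            + max(_dist(a, b) for a, b in zip(dx, dy)))

def toy_field(z, t):
    # Hypothetical smooth vector field standing in for the learned network.
    return [math.tanh(z[1]) + 0.1 * t, -math.tanh(z[0])]

tx = trajectory([0.5, -0.2], toy_field)
ty = trajectory([0.4, 0.3], toy_field)
print(c1_distance(tx, ty) > 0)  # True
```

The derivative term is what separates this norm from a plain sup distance between trajectories: two inputs whose paths stay close but move at different speeds still end up apart.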

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, and this is the friction.

Referee Report

2 major / 2 minor

Summary. The paper proposes two neural architectures, DeepKENN and ODE-KENN, motivated by the Kuratowski embedding theorem, to learn approximations to the Wasserstein-2 distance. DeepKENN aggregates distances across CNN feature maps using learnable positive weights, while ODE-KENN replaces the discrete layers with a Neural ODE to embed inputs into the function space C^1([0,1], R^d) with implicit regularization from trajectory smoothness. On MNIST using precomputed exact W_2 targets, the manuscript reports that ODE-KENN achieves 28% lower test MSE than a single-layer baseline and 18% lower than DeepKENN under matched parameter counts, along with a smaller generalization gap. The resulting models are positioned as fast surrogates for expensive W_2 oracles in downstream pairwise distance tasks.

Significance. If the reported MSE reductions hold under scrutiny and the architectures provide embeddings that meaningfully approximate the Wasserstein metric, the work supplies a practical tool for bypassing the computational cost of W_2 calculations in metric learning and optimal transport pipelines. The controlled comparison with matched parameter counts and the introduction of Neural ODE trajectories for continuous embeddings are concrete strengths that could influence scalable implementations of Wasserstein-based methods. The empirical focus on MNIST with exact targets offers a clear testbed, though broader significance hinges on generalization and verification of the embedding properties.

major comments (2)
  1. [§3.2] §3.2 (ODE-KENN construction): the claim that the Neural ODE trajectory realizes a Kuratowski-style embedding into C^1([0,1], R^d) that approximates W_2 is load-bearing for the paper's motivation, yet the manuscript provides no analysis or numerical check that the learned map satisfies the necessary isometry or distance-preservation properties of the classical theorem; without this, the performance gains could be explained by standard regression rather than the embedding framework.
  2. [§4] §4 (Experiments and results): the headline 28% and 18% test-MSE reductions are presented without accompanying details on the precise loss formulation, optimization schedule, data splits for the precomputed W_2 targets, or statistical significance across multiple random seeds; this information is required to assess whether the smaller generalization gap is robust or sensitive to unreported implementation choices.
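One way to report the 28 and 18 percent figures with the seed-level statistics requested here is a paired comparison: compute the relative reduction per seed, with both models sharing that seed's data split, then report the mean and sample standard deviation. The sketch assumes such a paired design; the function name is hypothetical and the numbers in the usage line are placeholders, not results from the paper.

```python
import statistics

def paired_reduction(mse_model, mse_ref):
    """Mean and sample std of the per-seed relative reduction 1 - mse_model / mse_ref.

    Entry k of both lists must come from the same seed and data split.
    """
    if len(mse_model) != len(mse_ref) or len(mse_ref) < 2:
        raise ValueError("need matched results from at least two seeds")
    reductions = [1.0 - m / r for m, r in zip(mse_model, mse_ref)]
    return statistics.mean(reductions), statistics.stdev(reductions)

# Placeholder values for three seeds (illustration only).
mean_red, sd_red = paired_reduction([0.70, 0.75, 0.71], [1.00, 1.02, 0.98])
print(f"{mean_red:.3f} +/- {sd_red:.3f}")
```

Pairing by seed removes split-to-split variation that affects both models equally, which usually makes the comparison tighter than contrasting two independent means.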
minor comments (2)
  1. [Abstract] The abstract and §4 should explicitly state the solver or algorithm used to precompute the exact W_2 targets on MNIST, as this affects reproducibility of the regression targets.
  2. [§3.1] Notation for the positive weights in the DeepKENN aggregation step could be formalized with an equation that shows the positivity constraint and how they are initialized or regularized.
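The equation this comment asks for could take the following form, assuming Euclidean per-layer distances and a softplus reparametrization for positivity (both are assumptions; φ_ℓ denotes the flattened output of layer ℓ and α_ℓ an unconstrained parameter):

```latex
d_\theta(x,y) = \sum_{\ell=1}^{L} w_\ell \, \lVert \varphi_\ell(x) - \varphi_\ell(y) \rVert_2,
\qquad w_\ell = \log\!\left(1 + e^{\alpha_\ell}\right) > 0,

d_\theta(x,z) \le \sum_{\ell=1}^{L} w_\ell \left( \lVert \varphi_\ell(x) - \varphi_\ell(y) \rVert_2
  + \lVert \varphi_\ell(y) - \varphi_\ell(z) \rVert_2 \right) = d_\theta(x,y) + d_\theta(y,z).
```

Positivity is what makes the second line valid: with every w_ℓ > 0 the aggregate inherits symmetry, nonnegativity, and the triangle inequality from each per-layer distance, so d_θ is a pseudometric for any value of α, whereas a negative weight could break the triangle inequality.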

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for the constructive and detailed comments, which have helped us identify areas for clarification and improvement. We address each major comment point by point below and will revise the manuscript to incorporate additional analysis and experimental details as outlined.

Point-by-point responses
  1. Referee: [§3.2] §3.2 (ODE-KENN construction): the claim that the Neural ODE trajectory realizes a Kuratowski-style embedding into C^1([0,1], R^d) that approximates W_2 is load-bearing for the paper's motivation, yet the manuscript provides no analysis or numerical check that the learned map satisfies the necessary isometry or distance-preservation properties of the classical theorem; without this, the performance gains could be explained by standard regression rather than the embedding framework.

    Authors: We agree that explicit verification of distance-preservation properties would strengthen the link to the Kuratowski embedding theorem. The ODE-KENN architecture is explicitly motivated by the theorem's construction of an isometric embedding into a Banach space of continuous functions, with the Neural ODE providing a continuous trajectory in C^1([0,1], R^d) and implicit regularization from smoothness. We do not claim that the learned map achieves exact isometry, as the network is trained to approximate W_2 via regression; rather, the embedding framework supplies the inductive bias for the architecture. To directly address the concern, the revised manuscript will add a new subsection in §3.2 (or §4) reporting numerical checks on a held-out set: specifically, the Pearson correlation between pairwise distances computed from the embedded trajectories (using the appropriate norm on C^1) and the ground-truth W_2 values, as well as a comparison of this correlation against a generic MLP regressor with matched capacity. These checks will demonstrate that the learned embedding preserves distances beyond what would be expected from unstructured regression, consistent with the smaller generalization gap observed under matched parameter counts. revision: partial

  2. Referee: [§4] §4 (Experiments and results): the headline 28% and 18% test-MSE reductions are presented without accompanying details on the precise loss formulation, optimization schedule, data splits for the precomputed W_2 targets, or statistical significance across multiple random seeds; this information is required to assess whether the smaller generalization gap is robust or sensitive to unreported implementation choices.

    Authors: We acknowledge that these details are necessary for reproducibility and for evaluating the robustness of the reported improvements. In the revised manuscript, Section 4 will be expanded to include: the precise loss (mean squared error between predicted and precomputed W_2 targets), the full optimization schedule (optimizer, learning rate schedule, number of epochs, and any regularization), the data split protocol for generating and partitioning the precomputed W_2 pairs (including train/validation/test ratios and how pairs were sampled), and results aggregated over multiple independent random seeds (with means and standard deviations) to establish statistical significance of the 28% and 18% MSE reductions as well as the generalization gap. revision: yes
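The correlation check promised in the first response is easy to compute, but Pearson correlation is invariant to affine rescaling, so a predictor with a constant additive bias still scores 1. A distortion ratio (largest over smallest ratio of predicted to true distance) is closer to the distance-preservation property the referee asks about. Both functions are hypothetical helpers, not the authors' code.

```python
import math

def pearson(xs, ys):
    """Sample Pearson correlation between two equal-length lists."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

def distortion(pred, true):
    """max(pred/true) / min(pred/true) over pairs with true > 0; 1.0 means exact up to scale."""
    ratios = [p / t for p, t in zip(pred, true) if t > 0]
    return max(ratios) / min(ratios)

true_w2 = [1.0, 2.0, 3.0]
biased = [11.0, 12.0, 13.0]  # perfectly correlated, but offset by a constant
print(round(pearson(biased, true_w2), 6), round(distortion(biased, true_w2), 3))  # 1.0 2.538
```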

Circularity Check

0 steps flagged

No significant circularity in derivation chain

full rationale

The paper motivates its architectures (DeepKENN and ODE-KENN) directly from the classical Kuratowski embedding theorem, an external result, and evaluates them via standard supervised regression on precomputed W2 targets from MNIST. No load-bearing step reduces a claimed prediction to a fitted input by construction, no self-citation chain supports the central premise, and the reported MSE improvements are ordinary empirical outcomes rather than tautological renamings or ansatzes smuggled in via prior author work. The argument is self-contained and is evaluated against external benchmarks.

Axiom & Free-Parameter Ledger

1 free parameter · 1 axiom · 0 invented entities

The central claim rests on the classical Kuratowski theorem plus several learnable parameters inside the neural networks; no new entities are postulated.

free parameters (1)
  • learnable positive weights for feature-map aggregation
    Introduced in DeepKENN to combine distances across CNN layers; values are fitted during training.
axioms (1)
  • standard math: the Kuratowski embedding theorem applies to the Wasserstein metric space
    Invoked in the motivation section to justify embedding inputs into a function space.
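For a finite set of measures the Kuratowski map can be checked directly. The sketch uses 1-D uniform empirical measures with equal numbers of atoms, where W2 has an exact closed form by sorting, and verifies that the sup-norm distance between embedded vectors reproduces W2. Building the embedding itself needs the W2 distance from every point to every reference point, so it does not remove the oracle cost; that is one reason a learned, approximate embedding is attractive. Function names are illustrative.

```python
import math
import random

def w2_1d(xs, ys):
    """Exact W2 between uniform empirical measures on R with the same number of atoms."""
    xs, ys = sorted(xs), sorted(ys)  # the monotone (quantile) coupling is optimal in 1-D
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(xs, ys)) / len(xs))

def kuratowski_embed(points, dist, base=0):
    """Map each point x to the vector (d(x, z) - d(x0, z)) over all z in the set."""
    x0 = points[base]
    return [[dist(x, z) - dist(x0, z) for z in points] for x in points]

def sup_diff(u, v):
    return max(abs(a - b) for a, b in zip(u, v))

rng = random.Random(0)
measures = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(6)]
emb = kuratowski_embed(measures, w2_1d)
worst = max(abs(sup_diff(emb[i], emb[j]) - w2_1d(measures[i], measures[j]))
            for i in range(6) for j in range(6))
print(worst < 1e-9)  # True
```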

pith-pipeline@v0.9.0 · 5456 in / 1219 out tokens · 27918 ms · 2026-05-10T20:18:45.675130+00:00 · methodology


Reference graph

Works this paper leans on

11 extracted references · 11 canonical work pages

  1. Nicolas Bonneel, Michiel Van De Panne, Sylvain Paris, and Wolfgang Heidrich. Displacement interpolation using Lagrangian mass transport. ACM Transactions on Graphics, 30(6):158, 2011.

  2. Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. Neural ordinary differential equations. In Advances in Neural Information Processing Systems, volume 31, 2018.

  3. Alexander Cloninger, Keaton Hamm, Varun Khurana, and Caroline Moosmüller. Linearized Wasserstein dimensionality reduction with approximation guarantees. Applied and Computational Harmonic Analysis, 74:101718, 2025.

  4. Marco Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, volume 26, 2013.

  5. Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Ale...

  6. Keaton Hamm, Nick Henscheid, and Shujie Kang. Wassmap: Wasserstein isometric mapping for image manifold learning. SIAM Journal on Mathematics of Data Science, 5(2):475–501, 2023.

  7. Justin Johnson, Alexandre Alahi, and Li Fei-Fei. Perceptual losses for real-time style transfer and super-resolution. In European Conference on Computer Vision, pages 694–

  8. William B. Johnson and Joram Lindenstrauss. Extensions of Lipschitz mappings into a Hilbert space. In Conference in Modern Analysis and Probability, volume 26 of Contemporary Mathematics, pages 189–206. American Mathematical Society, 1984.

  9. Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

  10. Gabriel Peyré and Marco Cuturi. Computational Optimal Transport, volume 11. 2019.

  11. Cédric Villani. Optimal Transport: Old and New. Springer, 2008.

A Architecture Details

Table 2: CNN encoder architecture shared by all three models.

Layer | Operation                            | Output shape | Flat dim
Conv1 | Conv(1→8, 5×5) + ReLU + MaxPool(2)   | (8, 14, 14)  | 1568
Conv2 | Conv(8→16, 3×3) + ReLU + MaxPool(2)  | (16, 7, 7)   | 784
Conv3 | Conv(16→32, 3×3) + ReLU + MaxPool(2) | (32, 3, 3)   | 288
FC1   | L...
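The Flat dim column of Table 2 is C × H × W of each block's output. The sketch reproduces the three listed rows from a 1 × 28 × 28 MNIST input, under the assumption (not stated in the table) that each convolution uses "same" padding and that MaxPool(2) floors odd sizes, which is how 7 × 7 becomes 3 × 3.

```python
def encoder_shapes(in_shape=(1, 28, 28), blocks=((8, 5), (16, 3), (32, 3))):
    """Return ((C, H, W), C*H*W) after each Conv + ReLU + MaxPool(2) block.

    blocks: (out_channels, kernel_size) per layer, as listed in Table 2.
    Assumes 'same' padding, so only the pooling changes H and W.
    """
    _, h, w = in_shape
    rows = []
    for c_out, _kernel in blocks:
        h, w = h // 2, w // 2  # MaxPool(2) with floor division
        rows.append(((c_out, h, w), c_out * h * w))
    return rows

for shape, flat in encoder_shapes():
    print(shape, flat)
```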