
arxiv: 2508.18168 · v3 · submitted 2025-08-25 · 💻 cs.CL

Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation

Pith reviewed 2026-05-18 21:14 UTC · model grok-4.3

classification 💻 cs.CL
keywords retrieval-augmented generation · joint stochastic approximation · end-to-end training · discrete latent variables · gradient estimation · question answering · knowledge-grounded dialogs

The pith

Joint stochastic approximation enables better end-to-end training of retrieval-augmented generation models by reducing gradient variance.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

Retrieval-augmented generation models consist of a retriever and generator connected in series, but end-to-end training requires marginalizing over discrete latent variables that represent relevant passages from a knowledge base. Traditional top-K marginalization introduces bias in the gradients while variational RAG produces high-variance estimates. The paper develops joint stochastic approximation, a stochastic extension of the EM algorithm, to handle this marginalization more effectively in a method called JSA-RAG. Experiments on five datasets for open-domain question answering and knowledge-grounded dialogs demonstrate that JSA-RAG significantly outperforms both vanilla RAG and VRAG. The gains are linked to improved generation quality, more accurate retrieval, and lower-variance gradient estimates.
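The bias of top-K marginalization can be seen on a toy example. The sketch below is illustrative only: it assumes one simple reading of top-K marginalization (a truncated sum with no renormalization of the retriever distribution), and the names `exact_marginal`, `topk_marginal`, and the random scores are hypothetical, not taken from the paper.

```python
import numpy as np

def softmax(scores):
    # Numerically stable softmax over retriever scores.
    shifted = scores - np.max(scores)
    weights = np.exp(shifted)
    return weights / weights.sum()

def exact_marginal(prior, likelihood):
    # p(y|x) = sum_z p(z|x) * p(y|x,z) over every passage z.
    return float(np.sum(prior * likelihood))

def topk_marginal(prior, likelihood, k):
    # Assumption: keep only the k passages with the highest retriever
    # probability and sum over them without renormalizing.
    top = np.argsort(prior)[::-1][:k]
    return float(np.sum(prior[top] * likelihood[top]))

rng = np.random.default_rng(0)
num_passages = 50
prior = softmax(rng.normal(size=num_passages))            # toy p(z|x)
likelihood = rng.uniform(0.0, 1.0, size=num_passages)     # toy p(y|x,z)

exact = exact_marginal(prior, likelihood)
budgets = (1, 5, 10, num_passages)
truncated = [topk_marginal(prior, likelihood, k) for k in budgets]

for k, value in zip(budgets, truncated):
    print(f"K={k:3d}  truncated={value:.4f}  exact={exact:.4f}")
```

Under this truncation the top-K value never exceeds the exact marginal and only matches it when K covers the whole knowledge base, which is the sense in which gradients computed from it are biased.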

Core claim

JSA-RAG applies the joint stochastic approximation algorithm to train retrieval-augmented generation models end-to-end by estimating gradients over discrete latent passages with lower variance than top-K marginalization or variational methods, resulting in better performance on open-domain question answering and knowledge-grounded dialog tasks across five datasets.

What carries the argument

Joint stochastic approximation (JSA) algorithm, a stochastic extension of the EM algorithm specialized for estimating discrete latent variable models.
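The summary above does not spell out how JSA draws latent passages. A minimal sketch, assuming the E-step is approximated with a Metropolis independence sampler that proposes passages from an auxiliary distribution and targets the posterior over passages, might look as follows. All names and numbers here (`run_chain`, the five-passage toy model, the uniform proposal) are hypothetical and chosen only for illustration.

```python
import numpy as np

def run_chain(joint, proposal, num_steps, seed=0):
    """Metropolis independence sampler over a discrete latent z.

    joint[z]    : unnormalized target p(z|x) * p(y|x,z)
    proposal[z] : proposal distribution q(z) (sums to 1)
    Returns the empirical visit frequency of each z.
    """
    rng = np.random.default_rng(seed)
    n = len(proposal)
    candidates = rng.choice(n, size=num_steps, p=proposal)
    uniforms = rng.random(num_steps)
    weights = joint / proposal            # importance weights w(z)
    state = int(rng.choice(n, p=proposal))
    counts = np.zeros(n)
    for candidate, u in zip(candidates, uniforms):
        # Accept with probability min(1, w(candidate) / w(state)).
        if u < min(1.0, weights[candidate] / weights[state]):
            state = int(candidate)
        counts[state] += 1
    return counts / num_steps

prior = np.array([0.4, 0.3, 0.15, 0.1, 0.05])        # toy p(z|x)
likelihood = np.array([0.1, 0.6, 0.2, 0.05, 0.9])    # toy p(y|x,z)
proposal = np.full(5, 0.2)                           # toy uniform q(z)

joint = prior * likelihood
posterior = joint / joint.sum()
frequencies = run_chain(joint, proposal, num_steps=50_000)

print("exact posterior :", np.round(posterior, 3))
print("chain frequency :", np.round(frequencies, 3))
```

In this sketch the chain only needs the unnormalized joint, so the marginal over the full knowledge base is never computed. In a training loop the accepted passage would then feed the parameter updates, which is omitted here.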

Load-bearing premise

The joint stochastic approximation algorithm produces gradient estimates with sufficiently low variance and without new biases that would prevent outperformance over top-K marginalization and VRAG.

What would settle it

A controlled experiment that directly measures and compares the variance of gradient estimates produced by JSA-RAG versus VRAG on one of the evaluation datasets while holding all other factors fixed.
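A measurement harness for such an experiment could be sketched as below. It is a toy, not the paper's protocol: the estimator is a stand-in that samples a passage from the exact posterior and returns the gradient of the log retriever probability with respect to the retriever logits, and the names `posterior_sample_gradient` and `gradient_statistics` are hypothetical. The point is only to show how total gradient variance can be estimated at fixed parameters, and why the number of samples per update has to be matched across methods.

```python
import numpy as np

def posterior_sample_gradient(prior, posterior, num_samples, rng):
    # Stand-in estimator: draw z from the posterior and average
    # d/d(logits) log p(z|x) = one_hot(z) - prior over the draws.
    draws = rng.choice(len(prior), size=num_samples, p=posterior)
    one_hot = np.eye(len(prior))[draws]
    return (one_hot - prior).mean(axis=0)

def gradient_statistics(estimator, num_trials, seed=0):
    # Repeat the estimator at fixed parameters; report the mean gradient
    # and the total variance (trace of the empirical covariance).
    rng = np.random.default_rng(seed)
    samples = np.stack([estimator(rng) for _ in range(num_trials)])
    mean = samples.mean(axis=0)
    total_variance = float(samples.var(axis=0, ddof=1).sum())
    return mean, total_variance

prior = np.array([0.4, 0.3, 0.15, 0.1, 0.05])        # toy p(z|x)
likelihood = np.array([0.1, 0.6, 0.2, 0.05, 0.9])    # toy p(y|x,z)
posterior = prior * likelihood / np.sum(prior * likelihood)
exact_gradient = posterior - prior   # gradient of log p(y|x) w.r.t. logits

mean_1, var_1 = gradient_statistics(
    lambda rng: posterior_sample_gradient(prior, posterior, 1, rng),
    num_trials=4000,
)
mean_8, var_8 = gradient_statistics(
    lambda rng: posterior_sample_gradient(prior, posterior, 8, rng),
    num_trials=4000,
)

print("total variance, 1 sample per update :", round(var_1, 4))
print("total variance, 8 samples per update:", round(var_8, 4))
```

Because averaging more samples shrinks the variance of any such estimator, a comparison between two training methods is only informative when both use the same sample budget per update.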

Original abstract

Retrieval-augmented generation (RAG) has become a widely recognized paradigm to combine parametric memory with non-parametric memories. An RAG model consists of two serial connecting components (retriever and generator). A major challenge in end-to-end optimization of the RAG model is that marginalization over relevant passages (modeled as discrete latent variables) from a knowledge base is required. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop joint stochastic approximation (JSA) based end-to-end training of RAG, which is referred to as JSA-RAG. The JSA algorithm is a stochastic extension of the EM (expectation-maximization) algorithm and is particularly powerful in estimating discrete latent variable models. Extensive experiments are conducted on five datasets for two tasks (open-domain question answering, knowledge-grounded dialogs) and show that JSA-RAG significantly outperforms both vanilla RAG and VRAG. Further analysis shows the efficacy of JSA-RAG from the perspectives of generation, retrieval, and low-variance gradient estimate.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

2 major / 2 minor

Summary. The paper proposes JSA-RAG, an end-to-end training method for retrieval-augmented generation models that applies joint stochastic approximation (a stochastic EM extension) to handle marginalization over discrete latent passages. It claims that this yields lower-variance gradient estimates than top-K marginalization or VRAG, resulting in significant outperformance on five datasets across open-domain question answering and knowledge-grounded dialog tasks, with supporting analysis on generation quality, retrieval effectiveness, and gradient variance.

Significance. If the empirical gains and low-variance properties hold under rigorous controls, the work would offer a practical advance in training RAG models with discrete retrieval latents, potentially improving stability and performance in knowledge-intensive NLP applications by better integrating retriever and generator optimization.

major comments (2)
  1. §3 (Method): The joint stochastic approximation update rules are described at a high level, but the manuscript does not provide an explicit bias/variance derivation for the estimator when sampling couples retriever and generator parameters; without this, it is unclear whether finite-sample JSA remains unbiased relative to the true marginal or delivers materially lower variance than VRAG at equal sample budgets.
  2. §4.2 (Experiments, main results table): The reported gains over VRAG on the five datasets lack accompanying gradient-norm or variance statistics, and it is not stated whether VRAG baselines used the same number of samples per update; this leaves open the possibility that observed improvements stem from optimization dynamics rather than the claimed variance reduction.
minor comments (2)
  1. Table 1 and Figure 3: axis labels and legend entries should explicitly note the number of samples used for each method to enable direct variance comparison.
  2. §2 (Related Work): The distinction between JSA and prior stochastic EM variants for discrete latents could be sharpened with a short equation contrasting the joint update against separate E/M steps.

Simulated Authors' Rebuttal

2 responses · 0 unresolved

We are grateful to the referee for the detailed and constructive comments. We address each major comment point by point below, providing clarifications and committing to revisions that strengthen the presentation of the method and experiments.

Point-by-point responses
  1. Referee: §3 (Method): The joint stochastic approximation update rules are described at a high level, but the manuscript does not provide an explicit bias/variance derivation for the estimator when sampling couples retriever and generator parameters; without this, it is unclear whether finite-sample JSA remains unbiased relative to the true marginal or delivers materially lower variance than VRAG at equal sample budgets.

    Authors: We thank the referee for highlighting this. The JSA procedure is a stochastic approximation to the EM algorithm for discrete latent variable models and inherits unbiasedness properties from the underlying framework when the sampling is properly coupled. Nevertheless, we agree that an explicit finite-sample bias and variance derivation for the joint retriever-generator estimator would improve clarity. In the revised manuscript we will add a dedicated paragraph (or short appendix subsection) deriving that the estimator is unbiased with respect to the true marginal likelihood gradient and that its variance is strictly lower than that of VRAG under an equal number of samples per update. The derivation will follow the standard stochastic approximation analysis while specializing the coupling step to the RAG setting. revision: yes

  2. Referee: §4.2 (Experiments, main results table): The reported gains over VRAG on the five datasets lack accompanying gradient-norm or variance statistics, and it is not stated whether VRAG baselines used the same number of samples per update; this leaves open the possibility that observed improvements stem from optimization dynamics rather than the claimed variance reduction.

    Authors: We appreciate the request for tighter experimental controls. The VRAG baselines were in fact configured with exactly the same number of samples per gradient step as JSA-RAG; this detail was stated in the experimental setup paragraph but not repeated in the caption of the main results table. While gradient-variance analysis appears in §4.3, we acknowledge that the primary performance table does not juxtapose these statistics. We will revise §4.2 to (i) explicitly note the matched sample budget in the table caption and (ii) add a compact column or supplementary table reporting average gradient norm and variance for both methods across the five datasets. These additions will make it straightforward to attribute performance differences to the variance-reduction property of the JSA estimator. revision: yes

Circularity Check

0 steps flagged

JSA-RAG applies an existing stochastic EM extension to RAG without self-referential derivation or fitted inputs renamed as predictions

Full rationale

The paper presents JSA-RAG as the direct application of the joint stochastic approximation algorithm (described as a known stochastic extension of EM) to end-to-end RAG training with discrete passage latents. No load-bearing step reduces by construction to a self-definition, a fitted parameter relabeled as a prediction, or a self-citation chain that supplies the uniqueness or ansatz. The claimed low-variance gradient estimates and outperformance are supported by experimental results on five datasets rather than by internal re-derivation of the JSA estimator itself. The derivation chain therefore remains self-contained against external benchmarks.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axiom · 0 invented entities

The approach rests on standard assumptions from the EM algorithm and stochastic approximation literature applied to discrete latent variable models in RAG; no new free parameters or invented entities are described in the abstract.

axioms (1)
  • domain assumption: Joint stochastic approximation is a stochastic extension of EM that is particularly powerful for estimating discrete latent variable models.
    Directly stated in the abstract as the basis for the proposed method.

pith-pipeline@v0.9.0 · 5740 in / 1146 out tokens · 44969 ms · 2026-05-18T21:14:31.420343+00:00 · methodology

