
arXiv: 2402.14735 · v2 · pith:T25E4UXXnew · submitted 2024-02-22 · 💻 cs.LG · cs.IT · math.IT · stat.ML

How Transformers Learn Causal Structure with Gradient Descent

classification 💻 cs.LG · cs.IT · math.IT · stat.ML
keywords causal · transformers · gradient · structure · in-context · latent · learn · learning

The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
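The data-processing-inequality step in the abstract can be illustrated with a small numerical sketch. This is not the paper's in-context task or its transformer: it assumes a single fixed Markov chain with a randomly drawn transition matrix, and the helper names (`mutual_information`, `chain_mi`) are hypothetical. It only shows that, along a chain, the mutual information between the last token and an earlier token is largest for the immediate parent, which is the property the abstract says lets the largest gradient entries single out edges of the causal graph.

```python
import numpy as np

def mutual_information(joint):
    """Mutual information (in nats) of a 2-D joint probability table."""
    px = joint.sum(axis=1, keepdims=True)   # marginal of the row variable, shape (S, 1)
    py = joint.sum(axis=0, keepdims=True)   # marginal of the column variable, shape (1, S)
    mask = joint > 0                        # skip zero cells to avoid log(0)
    return float(np.sum(joint[mask] * np.log(joint[mask] / (px @ py)[mask])))

def chain_mi(P, pi0, i, j):
    """I(X_i; X_j) for a Markov chain with transition matrix P and
    initial distribution pi0, where i < j are 0-based positions."""
    marg_i = pi0 @ np.linalg.matrix_power(P, i)                  # law of X_i
    joint = marg_i[:, None] * np.linalg.matrix_power(P, j - i)   # law of (X_i, X_j)
    return mutual_information(joint)

# Assumed toy setup: 4 states, 6 positions, random row-stochastic P.
rng = np.random.default_rng(0)
S, T = 4, 6
P = rng.dirichlet(np.ones(S), size=S)   # each row is a probability vector
pi0 = np.full(S, 1.0 / S)

# Mutual information between the last position and every earlier position.
mi_to_last = [chain_mi(P, pi0, i, T - 1) for i in range(T - 1)]
best_parent = int(np.argmax(mi_to_last))

for i, mi in enumerate(mi_to_last):
    print(f"I(X_{i}; X_{T - 1}) = {mi:.4f} nats")
print("position with largest mutual information:", best_parent)
```

Because X_i → X_{i+1} → … → X_{T-1} forms a Markov chain, the data processing inequality gives I(X_i; X_{T-1}) ≤ I(X_{i+1}; X_{T-1}), so the printed values are non-decreasing in i and the maximum sits at the parent position T-2.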

Forward citations

Cited by 7 Pith papers

Reviewed papers in the Pith corpus that reference this work. Sorted by Pith novelty score.

  1. Why Muon Outperforms Adam: A Curvature Perspective

    cs.LG 2026-06 conditional novelty 7.0

    Muon outperforms Adam by reducing curvature penalty via lower Normalized Directional Sharpness, as shown via Taylor approximation on LLM training and proven on stylized quadratic problems with heterogeneous curvature.

  2. A Global Characterization of $f$-Divergences Yielding PSD Mutual-Information Matrices

    cs.IT 2026-01 unverdicted novelty 7.0

    Pairwise f-mutual information matrices are positive semi-definite for all finite-alphabet distributions exactly when the f generator has a power series with all nonnegative coefficients that converges on the positive reals.

  3. Fast Wasserstein rates for estimating probability distributions of probabilistic graphical models

    math.ST 2025-10 unverdicted novelty 7.0

    Smoothness assumptions on graphical model kernels produce Wasserstein estimation rates determined by local graph structure rather than ambient dimension.

  4. Rigorous uncertainty quantification of probabilistic AI weather forecasts with conformal prediction

    physics.ao-ph 2026-06 unverdicted novelty 6.0

    Online conformal prediction post-processing guarantees calibrated uncertainty coverage for GenCast, NeuralGCM, and AIFS-ENS forecasts of temperature and precipitation including extremes.

  5. CausalDetox: Causal Head Selection and Intervention for Language Model Detoxification

    cs.CL 2026-04 unverdicted novelty 6.0

    CausalDetox identifies minimal attention heads causally linked to toxicity via Probability of Necessity and Sufficiency, then applies targeted inference-time steering or fine-tuning to reduce toxic generation while pr...

  6. Visual prompting reimagined: The power of the Activation Prompts

    cs.CV 2026-04 unverdicted novelty 6.0

    Activation prompts on intermediate layers outperform input-level visual prompting and parameter-efficient fine-tuning in accuracy and efficiency across 29 datasets.

  7. Provable Knowledge Acquisition and Extraction in One-Layer Transformers

    cs.LG 2025-07 unverdicted novelty 6.0

    In a stylized one-layer transformer, pre-training encodes factual knowledge via relation-specific feature directions and attention patterns; fine-tuning extracts it through a relation-covering mechanism that succeeds ...