How Transformers Learn Causal Structure with Gradient Descent
The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
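The data processing argument in the abstract can be illustrated on a toy case. The sketch below is not the paper's construction and does not compute any attention gradient: it assumes a hypothetical 3-state stationary Markov chain with a hand-picked transition matrix `P`, and computes the exact Shannon mutual information between a token and the token `lag` steps later. All function and variable names are illustrative. In a Markov chain the only causal parent of position t is position t-1, so the data processing inequality implies the mutual information should be largest at lag 1 and shrink as the lag grows.

```python
import math

def matmul(a, b):
    """Multiply two square matrices given as lists of rows."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def stationary(p, iters=1000):
    """Approximate the stationary distribution by power iteration."""
    n = len(p)
    pi = [1.0 / n] * n
    for _ in range(iters):
        pi = [sum(pi[i] * p[i][j] for i in range(n)) for j in range(n)]
    return pi

def lagged_mutual_information(p, lag):
    """Exact I(X_t; X_{t+lag}) in nats for a stationary chain with transitions p."""
    n = len(p)
    pi = stationary(p)
    # p_lag[i][j] = P(X_{t+lag} = j | X_t = i)
    p_lag = p
    for _ in range(lag - 1):
        p_lag = matmul(p_lag, p)
    mi = 0.0
    for i in range(n):
        for j in range(n):
            joint = pi[i] * p_lag[i][j]
            if joint > 0.0:
                # joint / (pi[i] * pi[j]) simplifies to p_lag[i][j] / pi[j]
                mi += joint * math.log(p_lag[i][j] / pi[j])
    return mi

# Hypothetical transition matrix, chosen only for illustration.
P = [
    [0.8, 0.1, 0.1],
    [0.1, 0.8, 0.1],
    [0.1, 0.1, 0.8],
]

mi_by_lag = [lagged_mutual_information(P, lag) for lag in (1, 2, 3, 4)]
for lag, mi in zip((1, 2, 3, 4), mi_by_lag):
    print(f"lag {lag}: I = {mi:.4f} nats")
```

The printed values decrease with the lag, so ranking positions by mutual information singles out the immediate predecessor. This is the intuition behind the abstract's claim that the largest entries correspond to edges of the latent causal graph.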
Forward citations
Cited by 7 Pith papers
- Why Muon Outperforms Adam: A Curvature Perspective
  Muon outperforms Adam by reducing curvature penalty via lower Normalized Directional Sharpness, as shown via Taylor approximation on LLM training and proven on stylized quadratic problems with heterogeneous curvature.
- A Global Characterization of $f$-Divergences Yielding PSD Mutual-Information Matrices
  Pairwise f-mutual information matrices are positive semi-definite for all finite-alphabet distributions exactly when the f generator has a power series with all nonnegative coefficients that converges on the positive reals.
- Fast Wasserstein rates for estimating probability distributions of probabilistic graphical models
  Smoothness assumptions on graphical model kernels produce Wasserstein estimation rates determined by local graph structure rather than ambient dimension.
- Rigorous uncertainty quantification of probabilistic AI weather forecasts with conformal prediction
  Online conformal prediction post-processing guarantees calibrated uncertainty coverage for GenCast, NeuralGCM, and AIFS-ENS forecasts of temperature and precipitation including extremes.
- CausalDetox: Causal Head Selection and Intervention for Language Model Detoxification
  CausalDetox identifies minimal attention heads causally linked to toxicity via Probability of Necessity and Sufficiency, then applies targeted inference-time steering or fine-tuning to reduce toxic generation while pr...
- Visual prompting reimagined: The power of the Activation Prompts
  Activation prompts on intermediate layers outperform input-level visual prompting and parameter-efficient fine-tuning in accuracy and efficiency across 29 datasets.
- Provable Knowledge Acquisition and Extraction in One-Layer Transformers
  In a stylized one-layer transformer, pre-training encodes factual knowledge via relation-specific feature directions and attention patterns; fine-tuning extracts it through a relation-covering mechanism that succeeds ...