One Step of Gradient Descent is Provably the Optimal In-Context Learner with One Layer of Linear Self-Attention


Bibliographic Details
Published in: arXiv.org
Main Authors: Mahankali, Arvind; Hashimoto, Tatsunori B.; Ma, Tengyu
Format: Paper
Language: English
Published: Ithaca: Cornell University Library, arXiv.org, 07.07.2023

Summary: Recent works have empirically analyzed in-context learning and shown that transformers trained on synthetic linear regression tasks can learn to implement ridge regression, which is the Bayes-optimal predictor, given sufficient capacity [Akyürek et al., 2023], while one-layer transformers with linear self-attention and no MLP layer will learn to implement one step of gradient descent (GD) on a least-squares linear regression objective [von Oswald et al., 2022]. However, the theory behind these observations remains poorly understood. We theoretically study transformers with a single layer of linear self-attention, trained on synthetic noisy linear regression data. First, we mathematically show that when the covariates are drawn from a standard Gaussian distribution, the one-layer transformer which minimizes the pre-training loss will implement a single step of GD on the least-squares linear regression objective. Then, we find that changing the distribution of the covariates and weight vector to a non-isotropic Gaussian distribution has a strong impact on the learned algorithm: the global minimizer of the pre-training loss now implements a single step of pre-conditioned GD. However, if only the distribution of the responses is changed, then this does not have a large effect on the learned algorithm: even when the response comes from a more general family of nonlinear functions, the global minimizer of the pre-training loss still implements a single step of GD on a least-squares linear regression objective.
ISSN: 2331-8422
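
The summary above says that one layer of linear self-attention can implement a single step of GD on the least-squares objective. The NumPy sketch below is an illustration of that idea, not the paper's code: the token format (each token is [x_i; y_i], with y set to 0 for the query), the hand-constructed attention weights, and the step size eta are all assumptions made for this example. It checks that, with those choices, the attention layer's prediction for the query matches the prediction of one GD step from a zero initialization.

```python
# Illustrative sketch only: one layer of linear self-attention whose output on the
# query token equals one step of gradient descent on (1/2)||Xw - y||^2 from w = 0.
# The weight construction and step size below are assumptions for illustration.
import numpy as np

rng = np.random.default_rng(0)
d, n = 5, 20                      # covariate dimension, number of in-context examples
eta = 1.0 / n                     # assumed GD step size

# Synthetic noisy linear regression prompt: (x_1, y_1), ..., (x_n, y_n), query x_q.
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_star + 0.1 * rng.normal(size=n)
x_q = rng.normal(size=d)

# Tokens are [x_i; y_i]; the query token carries y = 0.
Z = np.concatenate([np.hstack([X, y[:, None]]),
                    np.hstack([x_q[None, :], [[0.0]]])])

# Hand-chosen linear self-attention weights (one possible construction):
# queries and keys read the x-part of each token, values read the (scaled) y-part.
W_Q = np.vstack([np.eye(d), np.zeros((1, d))])   # (d+1) x d
W_K = W_Q.copy()
W_V = np.zeros((d + 1, d + 1))
W_V[d, d] = eta                                  # scale the y-coordinate by eta

# Linear self-attention (no softmax), residual form: Z + (Z W_Q)(Z W_K)^T (Z W_V).
attn = (Z @ W_Q) @ (Z @ W_K).T
out = Z + attn @ (Z @ W_V)
pred_attention = out[-1, -1]      # y-slot of the query token after one layer

# One step of GD on (1/2)||Xw - y||^2 from w_0 = 0: w_1 = eta * X^T y.
pred_gd = x_q @ (eta * X.T @ y)

print(pred_attention, pred_gd)    # the two predictions coincide up to float error
```

For the non-isotropic setting described in the summary, the analogous prediction would take the form x_q^T Gamma X^T y for some preconditioning matrix Gamma absorbed into the attention weights; the sketch above only covers the isotropic case.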