Do Machine Learning Models Memorize or Generalize?
2021年、研究者たちは玩具的な課題で一連の小さなモデルを訓練している最中に、驚くべき発見をしました。訓練を長く続けたところ、学習データを丸暗記していたモデルが、未見の入力でも正しく振る舞う汎化へと突然切り替わるケースが見つかったのです。学習データへの当てはまりが終わってからかなり経って、汎化が唐突に起こるこの現象は grokking(グロッキング)と呼ばれ、大きな注目を集めました。
より複雑なモデルでも、学習を長く続ければ突然汎化が起こるのでしょうか。大規模言語モデル(LLM)は世界について深い理解を持っているように見える一方で、訓練に使った膨大なテキストから暗記した断片を繰り返し出力しているだけかもしれません。汎化しているのか、ただ記憶を吐き出しているだけなのかを、どのように見分ければよいのでしょうか。
この記事では、ごく小さなモデルの学習ダイナミクスを観察し、モデルが見つけた解をリバースエンジニアリングします。その過程で、急速に発展しているメカニスティック解釈(mechanistic interpretability)の世界を体験してみましょう。今日の巨大なモデルにこれらの手法をどのように適用すべきかはまだ明確ではありませんが、小さな例から始めることで直感を養い、LLM に関する重要な問いに答えていく足がかりになります。
モジュラー加算は、グロッキング研究におけるショウジョウバエのような存在です。上の折れ線グラフは、 を予測するよう訓練したモデルから得たものです。まず、すべての の組み合わせをランダムに訓練データとテストデータに分割します。何千ステップもの学習で訓練データを使ってモデルの出力を調整し、テストデータはモデルが一般的な解を獲得できているかを確認するためだけに用います。
モデルのアーキテクチャも同じくらい単純で、 という 24 ニューロンの一層 MLP です。モデルの全重みは下のヒートマップに表示してあり、上の折れ線グラフにマウスオーバーすると学習中の変化が確認できます。
モデルは、 のうち入力 と に対応する 2 列を取り出して足し合わせ、24 個の数値からなるベクトルを作ります。その後、このベクトルの負の値を 0 にし、最後に更新後のベクトルに最も近い の列を出力します。
学習初期の重みはノイズが多いものの、テストデータでの精度が上がりモデルが汎化へと
学習終盤の周期数でニューロンをグループ化し、それぞれを別の折れ線で描くとこの様子がさらに分かりやすくなります。
こうした周期パターンは、モデルが何らかの数学的構造を学んでいることを示唆しています。しかもテスト例が解け始めたタイミングで現れることから、汎化と密接に関係していると考えられます。では、モデルは なぜ 暗記的な解から離れるのでしょうか。そして、どのような 解が汎化をもたらしているのでしょうか。
この二つの問いを同時に解くのは簡単ではありません。そこで、どのような解が汎化なのかをあらかじめ知っている、さらに単純な課題を作り、モデルが最終的にそれを学ぶ理由を理解することにします。
1 と 0 から成る長さ 30 のランダムな列を用意し、先頭 3 桁に含まれる 1 の数が奇数かどうかを予測するようモデルを学習させます。例えば
ここでも 1 層の MLP を用い、1,200 個のシーケンスを固定バッチで学習させます。初期段階では訓練精度だけが向上し、モデルが訓練データを暗記していることがわかります。モジュラー加算のときと同様に、テスト精度はしばらくランダムなままですが、一般的な解を学ぶと急激に上昇します。
この簡略化された例では、なぜこうなるのかが分かりやすくなります。学習中、モデルには二つのことを同時に求めています。正しいラベルに高い確率を割り当てること(損失 を小さくする)と、重みの大きさを小さく保つこと(ウェイトディケイ )。モデルが汎化する直前には、重みを小さくする代わりに正しい出力に関する損失がわずかに増えるため、訓練損失は少し上昇します。
テスト損失が急激に下がるため、あたかもモデルが瞬間的に汎化へ切り替わったように見えます。しかし学習を通じて重みを追っていくと、その多くが二つの解の間を滑らかに遷移していることが分かります。最後に残った紛らわしい桁への重みがウェイトディケイによって刈り取られた瞬間に、急激な汎化が起きるのです。
グロッキングは偶発的な現象であり、モデルサイズやウェイトディケイ、データサイズなどのハイパーパラメータが適切でなければ生じません。ウェイトディケイが小さすぎると、モデルは訓練データへの過学習から抜け出せません。少し増やすと、いったん暗記した後で汎化へと押し出されます。さらに強くするとテスト損失と訓練損失が同時に下がり、モデルはいきなり汎化に到達します。逆に強すぎると何も学べなくなります。
下の図では、1 と 0 の課題についてハイパーパラメータを変えながら 1,000 以上のモデルを学習させました。学習結果にはノイズがあるため、各ハイパーパラメータ設定につき 9 個のモデルを訓練しています。
この少し人工的な 1 と 0 の課題では、暗記と汎化を意図的に引き起こせることが分かりました。それでは、モジュラー加算ではなぜ同じことが起きるのでしょうか。一層の MLP がモジュラー加算をどのように解いているのかを、解釈しやすい汎化解を構成することで理解してみましょう。
というモジュラー演算の問題は本質的に周期的で、合計が 67 を超えるたびに答えが巻き戻ります。数学的には、 と を円周上に並べるとイメージしやすく、汎化したモデルの重みにも周期的なパターンが現れていました。この性質を利用しているのかもしれません。
そこで、あらかじめ問題を解きやすくするために、各入力 について と を計算し、 と を円周上に配置する埋め込み行列を用意します。
次に、この一層 MLP で と を学習させます。
わずか 5 個のニューロンでも、このモデルは完全な精度の解にたどり着きます。
学習済みパラメータを眺めると、すべてのニューロンがほぼ同じノルムに
の円周上で隣り合うニューロン同士を結ぶと面白いパターンが現れます。 が の 2 倍の速さで円を回っているのです。
この解がどのように機能しているのか詳しく理解する必要はありません(2 倍の回転によって と のような入力を同じ場所に写せる仕組みは、付録Aで解説しています)。重要なのは、20 個のパラメータでモジュラー加算を解く構成を見つけたという点です。では、冒頭で扱った 3,216 パラメータのモデルにも同じアルゴリズムが潜んでいるのでしょうか。そして大きなモデルは、暗記から汎化へなぜ切り替わるのでしょうか。
こちらが最初に紹介した のモデルです。周期性をあらかじめ仕込まず、ゼロから学習させています。
先ほど構築した解では が 1 周するだけでしたが、このモデルでは複数の周波数が現れています。
そこで、離散フーリエ変換(DFT)を使って周波数成分を取り出しました。入力全体にわたって学習された周期パターンを因数分解し、先ほどの構築済み解における と に相当するものを得ます。各ニューロンについて、1 から 33 までのすべての周期に対する と の値が得られます。上の波形チャートは、全周波数のうち最大の と を持つものを基準にニューロンをグループ分けした結果です。
1 と 0 の課題と同様に、ウェイトディケイはモデルが
最終的に学習された周波数ごとにニューロンをまとめ、各ニューロンの DFT の と 成分をプロットすると、構築済み解で見たのと同じ星形が現れます。
この学習済みモデルは、構築済み解と同じアルゴリズムを使っているのです! 下図では各周波数のニューロンが出力にどのように寄与しているかを示しており、 を計算している様子がわかります。
ウェイトディケイのペナルティを受ける大きな重みを使わずに損失を下げるために、モデルはいくつかの周波数を組み合わせ、建設的干渉を活用しています。周波数 4、5、7、26 に特別な意味があるわけではありません。下の他の学習例をクリックすると、このアルゴリズムの様々なバリエーションが学習される様子が見られます。
ここまでで、一層 MLP がモジュラー加算を解くメカニズムと、それが学習中にどのように現れるかを理解できました。それでもなお、暗記と汎化を巡る興味深い疑問が多く残されています。
先ほど可視化したモデル、すなわち をそのまま訓練しても、ウェイトディケイを加えただけではモジュラー演算で汎化には至りません。少なくともどちらかの行列を分解する必要があります。
離散フーリエ変換を行った後の汎化解はスパースでしたが、縮約した行列はノルムが大きくなりました。これは や に直接ウェイトディケイをかけても、この課題に適した帰納バイアスにはならないことを示唆しています。
一般論として、ウェイトディケイはさまざまなモデルを訓練データの暗記から遠ざけます。過学習を避ける他の手法としては、ドロップアウトやモデル規模の縮小、あるいは数値的に不安定な最適化アルゴリズムさえ挙げられます。これらの手法は複雑で非線形な相互作用を持つため、どの手法が最終的に汎化を引き出すかを事前に予測するのは困難です。例えば ではなく を縮約すると、うまくいく設定もあれば、かえって悪化する設定もありました。
一つの仮説として、訓練集合を暗記する方法のほうが、汎化解よりもずっと多く存在する可能性があります。正則化がない、もしくは弱い場合には、統計的に暗記が先に起こりやすいというわけです。ウェイトディケイのような正則化は、スパースな解を好み、密な解を嫌うなど、特定の解を優先する働きをします。
最近の研究では、汎化は構造化された表現と関連していると示唆されています。しかしこれは必要条件ではありません。対称な入力を持たない MLP のバリエーションでは、モジュラー加算を解く際に円形とは異なる表現を学ぶこともあります。私たちも、構造化された表現があっても汎化の十分条件ではないことを観察しました。ウェイトディケイなしで学習した小さなモデルは、最初こそ汎化しますが、周期的な埋め込みを使った暗記へと切り替わってしまいました。
ハイパーパラメータによっては、汎化から暗記へ、再び汎化へと行き来するモデルさえ見つかります。
現実のタスクで訓練された大きなモデルでもグロッキングは起きるのでしょうか。これまでの観測では、小さなトランスフォーマーや MLP を使ったアルゴリズム的タスクでグロッキングが報告されています。その後、特定のハイパーパラメータ範囲で、画像・テキスト・表データを含むより複雑なタスクでもグロッキングが確認されました。多様なタスクをこなせる巨大モデルでは、学習中に複数の事柄を異なる速度でグロッキングしている可能性もあります。
グロッキングが起こる前に予測しようという試みも有望です。汎化解に関する知識やデータ全体の構造に関する理解を必要とするものもありますが、訓練損失の解析だけに基づく手法もあり、大きなモデルにも適用できるかもしれません。いつモデルが暗記した情報を復唱しているのか、いつより豊かな表現を使っているのかを判別するツールが将来生まれることを期待しています。
モジュラー加算の解を理解するだけでも簡単ではありませんでした。では、より大きなモデルを理解する希望はあるでしょうか。ここで行った 20 パラメータモデルや、さらに単純なブールパリティ問題への寄り道のように、次のような道筋が考えられます。1) より強い帰納バイアスと少ない自由度を持つ単純なモデルを学習する。2) それらを用いて大きなモデルの謎めいた部分を説明する。3) 必要に応じて繰り返す。大規模モデルの理解を深めるうえで実りあるアプローチであり、巨大モデルを用いて小さなモデルを説明したり、内部表現を分解したりする研究を補完するものだと私たちは考えています。さらに、こうしたメカニスティックな解釈のアプローチは、時間をかければニューラルネットワークが学習したアルゴリズムを見つけ出す作業自体を容易にしたり自動化したりするパターンの発見につながるかもしれません。
本記事の制作にあたり、Ardavan Saeedi、Crystal Qian、Emily Reif、Fernanda Viégas、Kathy Meier-Hellstern、Mahima Pushkarna、Minsuk Chang、Neel Nanda、Ryan Mullins の協力に感謝します。
2 つの円形埋め込みと完全に線形なモデルを組み合わせれば、 をほぼ計算できます。
これでうまくいきますが、少しズルをしていることに気づくでしょうか。unembed が円を 2 周してしまっています。出力したいのは「
そこで を組み込み、重複した出力を解消します。
こうして円を折り返すように構成することで、モデルは「
この構成を式で表すと次のようになります。
法 、ニューロン(方向)数 を等間隔に配置すると、
興味深いことに、この円形構成には少し歪みがあり、厳密解にはなりません。
が示したように、活性化関数として の代わりに を用いると、理論的に厳密な解になります。
簡単のため、( における各数の角度間隔)と ( における各ニューロンの角度間隔)と置きます。
を 次元ベクトル として書き直すと、
となります。これは上式の に と の定義を代入し、 という三角恒等式を用いた結果です。
ここから次のことが証明できます。
と という二つの三角恒等式を用いると、
ここで、 が円周上で等間隔になるとき であることに注意します。第 1 項と第 3 項の総和はそれぞれ 、 ずつ円周を回り、 のとき第 1 項の総和は 0、 のとき第 3 項の総和も 0 になります。したがって、
最初の係数は入力に依存しない正の定数なので、式が最大になるのは が最大となるとき、すなわち のときです。
つまり、ウェイトディケイ付きの 活性化というごく一般的なモデル設定でも、スパースな離散フーリエ変換と 活性化で得られる厳密解に十分近い帰納バイアスが働き、汎化の方向へ押し出されます。ただし訓練データを暗記する余地も残されている、というわけです。
モジュラー加算では、2 つの入力 と 、そして法 を扱います。目的は、 を で割った余りを求めることです。 この種の加算は時計の針に例えられることが多く、時刻を足し合わせるときに 12 を法として結果を報告する(例:8 時から 5 時間後は 1 時)ためです。 モジュラー加算は一見単純で、実際に単純です。数千個のモデルを簡単に学習させることができ、神経科学におけるショウジョウバエのように扱えます。すなわち、コネクトーム をシナプス単位で抽出できるほど小さく、それでいてシステム全体について新しい洞察を与えてくれるのです。内部をすべて可視化すれば、小さなモデルの挙動を十分理解できます。
67 という数字に特別な意味はありません。グロッキングを示すための候補はいくつもありますが、67 は課題が簡単すぎず、可視化が煩雑になりすぎないちょうど良い値です。
モデルはクロスエントロピー損失と AdamW、全データバッチで学習させています。正則化の節とトレーニング用 Colab に詳細があります。 MLP に馴染みがない場合は、playground.tensorflow.org が良いスタート地点です。 記法を少し補足しておきます。 と の各列は 0 から 66 までの数字を表します。 と は入力をワンホットで符号化したもので、それぞれ から 1 列を選びます。 は負の数を 0 に置き換える関数で、 を書き換えたものです。
ここでは出力が 0 か 1 だけなので、 は 1 列で構いません。モジュラー加算ではすべての出力値に対応する列が必要でした。 また、 の最後の列は常に 1 に固定し、バイアス項として働かせています。
「A Tale of Two Circuits: Grokking as Competition of Sparse and Dense Subnetworks」の付録 D に、4 つのニューロンで得られる汎化解の説明があります。
これまでは 精度、すなわち正しいラベルが最も確からしいと判定された割合を示してきました。実際の学習では、微分可能な目的関数を最適化するのが一般的です。本記事のすべてのモデルは クロスエントロピー損失を用いており、高い確率で誤答した場合に大きなペナルティが課されます。 損失の定義によってはウェイトディケイなどの正則化項を含む場合もありますが、ここに掲載している損失グラフはクロスエントロピーの成分のみを描いています。
1 と 0 の課題では L1 ウェイトディケイ を使用しました。
より一般的なのは L2 ウェイトディケイ で、多数の小さな重みを生み出すため、この課題では冗長なニューロンが現れやすくなります。

訓練データでは高い性能を示すもののテストデータでは低い性能しか発揮できない状態を過学習と呼びます。本記事で暗記しているモデルがまさにその例です。一般に、単純なモデルほど過学習しにくく、粗い判断規則しか持てないぶん、より多くの一般化を強いられます。ただしモデルがタスクに対して単純すぎると、求められる判断規則を学習できない恐れがあります。研究者は、パラメータ数を減らしたり、ウェイトディケイでパラメータの大きさを抑えたりするなど、さまざまな手法でモデルを単純化します。
と を計算すると、単位円上に等間隔の点が得られます。
を単位円に配置すると次のようになります。

離散フーリエ変換は、値の列を正弦波と余弦波に分解することで(今回であれば特定のニューロンの重み)、周期性を解析する手法です。関数が周期的であればあるほど、正弦波と余弦波で表現しやすくなり、DFT の出力はスパースになります。
可視化しやすいよう、最終的な周波数と位相に基づいてニューロンの並び替えを行っています。
モデルは、入力に対するニューロンの活性と の内積を計算し、ソフトマックスを取ることで確率を生成します。特定の周波数に属するニューロンの活性だけで内積を計算すると、その周波数グループがどの出力を高めたり抑えたりしているかが分かります。 付録Aにあるように、これらのロジットが波形になる理由は、各周波数グループが、自分たちの周波数で構成された 上の全ての数値に対して、正解がどれだけ近いかを出力しているからです。
ここで示したモデルはどちらも非常に小さなものです。下段のモデルでは、最終的に汎化するようハイパーパラメータを調整しています。局所解から抜けられるよう少し大きめのモデルにし、訓練データを増やして低損失の暗記解を見つけにくくし、ウェイトディケイも適用しています。
Grokking: Generalization Beyond Overfitting On Small Algorithmic Datasets Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). arXiv preprint arXiv:2201.02177.
Omnigrok: Grokking Beyond Algorithmic Data Liu, Z., Michaud, E. J., & Tegmark, M. (2022, September). In The Eleventh International Conference on Learning Representations.
A Toy Model of Universality: Reverse Engineering How Networks Learn Group Operations Chughtai, B., Chan, L., Nanda, N. (2023). International Conference on Machine Learning.
The Clock and the Pizza: Two Stories in Mechanistic Explanation of Neural Networks Zhong, Z., Liu, Z., Tegmark, M., & Andreas, J. (2023). arXiv preprint arXiv:2306.17844.
Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit Boaz Barak, Benjamin L. Edelman, Surbhi Goel, Sham Kakade, Eran Malach, Cyril Zhang. (2022) Advances in Neural Information Processing Systems, 35, 21750-21764.
Grokking modular arithmetic Andrey Gromov (2023). arXiv preprint arXiv:2301.02679.
On the Dangers of Stochastic Parrots: Can Language Models Be Too Big?🦜 Bender, E. M., Gebru, T., McMillan-Major, A., & Shmitchell, S. (2021, March). In Proceedings of the 2021 ACM conference on fairness, accountability, and transparency (pp. 610-623).
Emergent World Representations: Exploring a Sequence Model Trained on a Synthetic Task Li, K., Hopkins, A. K., Bau, D., Viégas, F., Pfister, H., & Wattenberg, M. (2022, September). In The Eleventh International Conference on Learning Representations.
Mechanistic Interpretability, Variables, and the Importance of Interpretable Bases Olah, C., 2022. Transformer Circuits Thread.
Progress Measures for Grokking via Mechanistic Interpretability Nanda, N., Chan, L., Lieberum, T., Smith, J., & Steinhardt, J. (2022, September). In The Eleventh International Conference on Learning Representations.
A Tale of Two Circuits: Grokking as Competition of Sparse and Dense Subnetworks William Merrill, Nikolaos Tsilivis, Aman Shukla. (2023). arXiv preprint arXiv:2303.11873.
Unifying Grokking and Double Descent Davies, X., Langosco, L., & Krueger, D. (2022, November). In NeurIPS ML Safety Workshop.
Double Descent Demystified: Identifying, Interpreting & Ablating the Sources of a Deep Learning Puzzle Rylan Schaeffer, R., Khona, M., Robertson, Z., Boopathy, A., Pistunova, K., Rocks, J., Rani Fiete, I., & Koyejo, O. (2023). arXiv preprint arXiv:2303.14151.
The Slingshot Mechanism: An Empirical Study of Adaptive Optimizers and the Grokking Phenomenon Thilak, V., Littwin, E., Zhai, S., Saremi, O., Paiss, R., & Susskind, J. (2022). arXiv preprint arXiv:2206.04817.
Towards Understanding Grokking: An Effective Theory of Representation Learning Liu, Z., Kitouni, O., Nolte, N. S., Michaud, E., Tegmark, M., & Williams, M. (2022). Advances in Neural Information Processing Systems, 35, 34651-34663.
The Goldilocks Zone: Towards Better Understanding of Neural Network Loss Landscapes Fort, S., & Scherlis, A. (2019, July). In Proceedings of the AAAI conference on artificial intelligence (Vol. 33, No. 01, pp. 3574-3581).
The Quantization Model of Neural Scaling Eric J. Michaud, Ziming Liu, Uzay Girit, Max Tegmark, O. (2023). arXiv preprint arXiv:2303.13506.
Grokking of Hierarchical Structure in Vanilla Transformers Murty, S., Sharma, P., Andreas, J., & Manning, C. D. (2023). arXiv preprint arXiv:2305.18741.
Predicting Grokking Long Before it Happens: A Look Into the Loss Landscape of Models Which Grok Notsawo Jr, P., Zhou, H., Pezeshki, M., Rish, I., & Dumas, G. (2023). arXiv preprint arXiv:2306.13253.
Language models can explain neurons in language models Bills, S., Cammarata, N., Mossing, D., Tillman, H., Gao, L., Goh, G., Sutskever, I., Leike, J., Wu, J., & Saunders, W. 2023. OpenAI Blog
Does Circuit Analysis Interpretability Scale? Evidence from Multiple Choice Capabilities in Chinchilla Tom Lieberum, Matthew Rahtz, János Kramár, Neel Nanda, Geoffrey Irving, Rohin Shah, Vladimir Mikulik (2023). arXiv preprint arXiv:2307.09458.
Toy Models of Superposition Elhage, N., Hume, T., Olsson, C., Schiefer, N., Henighan, T., Kravec, S., Hatfield-Dodds, Z., Lasenby, R., Drain, D., Chen, C., Grosse, R., McCandlish, S., Kaplan, J., Amodei, D., Wattenberg, M. and Olah, C., 2022. Transformer Circuits Thread.
The Connectome of an Insect Brain Winding, M., Pedigo, B. D., Barnes, C. L., Patsolic, H. G., Park, Y., Kazimiers, T., … & Zlatic, M. (2023). Science, 379(6636), eadd9330.
Multi-Scale Feature Learning Dynamics: Insights for Double Descent Pezeshki, M., Mitra, A., Bengio, Y., & Lajoie, G. (2022, June). In the International Conference on Machine Learning (pp. 17669-17690). PMLR.
Superposition, Memorization, and Double Descent Henighan, T., Carter, S., Hume, T., Elhage, N., Lasenby, R., Fort, S., Schiefer, N., and Olah, C., 2023. Transformer Circuits Thread.