StatsBeginner: 初学者の統計学習ノート

統計学およびR、Pythonでのプログラミングの勉強の過程をメモっていくノート。たまにMacの話題。

TransformerやAttentionの分かりにくい点についてのメモ

 ChatGPTの「GPT」はGenerative Pretrained Transformerの略であり、TransformerというのはGoogleが2017年に発表した『Attention is all you need』という論文で提案されたディープラーニングの画期的なアーキテクチャで、その論文のタイトル通り「Attention」という機構だけを使い倒している点が特徴的である。


 ……という話は色々なところで紹介されているのだが、私のような素人が読むと誤解してしまうような記述も少なくない。
 「仕組みを分かっていない人がいい加減な説明をしている」というよりは、「仕組みをよく知っている人が、素人の気持ちをあまり考えずに説明している」という感じで、悪気は全くないと思うし誤解するのは不勉強なこちらが悪いのだが、読むほうとしては困ることも多い。
 そこで、にわか機械学習ユーザである自分がall you need論文や巷の解説を読んでいて理解に詰まってしまった点について、簡単にメモしておきたい。ただし以下は走り書きなので、他の人が読んで分かりやすいように配慮はしていないし(笑)、間違ってるかも知れません。


 All you need論文はこちら:
[1706.03762] Attention Is All You Need

 入門的な解説はこのシリーズが詳しいかな:
Transformers Explained Visually (Part 1): Overview of Functionality | by Ketan Doshi | Towards Data Science


「アテンション機構のみ」で構成されてるわけではない

 これは中身を知っている人にとっては当たり前すぎて「揚げ足取り」のように思われそうだが、よく知らない人が「トランスフォーマーはアテンションという仕組みだけを使った画期的な手法である」と聞けば、やはり誤解したり混乱したりするのではないだろうか。
 「Attention is all you need」というのは、要するに「RNN(や畳込み)を使わない」という意味だ*1
 トランスフォーマーの内部では、順伝播型のニューラルネットがアテンション機構と同じぐらいたくさん出てきて、そこで色々学習しているので、全く「アテンションのみ」ではない。そもそもトランスフォーマーの場合、狭義のアテンション部分は学習パラメータを持ってないし。

「重要な情報に注目する仕組み」では何も説明できていない

 Attentionは「注意」「注目」といった意味であり、アテンション機構は「入力データのどの部分に注目すればよいかを判断している」と説明されることが多い。言い換えると、データの中で「どの部分が重要か」を判断し、情報に重み付けをしているということだ。
 しかし、「重要な情報に注目して重み付けする」のはただの順伝播型ニューラルネットでも同じなので、「アテンションは重要な情報に注目する仕組みです!」と言われても、何が凄いのかさっぱり分からない。
 実際の意味合いとしては、ふつうのニューラルネットにベクトルを投入して学習されるパラメータは「ベクトル内の各要素についての重み付け」を表しているのに対し、アテンションがやっているのは、「ベクトルがたくさんならんだ系列データ」を扱う際に、ベクトル間の関係を重み付けする仕事という感じだが、なんかうまい言い方を考える必要があると思う。

Attentionには系列「全体」の意味を捉えるという役割もある

 アテンション機構は、系列内のとある要素(ベクトル)から見て、ほかの要素のどれが重要なもの(強い関連性を持つ)であるかを判断している。しかし同時にそれを通じて、「系列全体の情報を取り込む=混ぜ合わせる」という役割を果たしてもいる。この両面性こそが重要なのだと考えるべきだろう。
 「アテンションは重要な情報に注目する仕組みです」と言われてしまうと、情報全体の特定の部分だけを活用する仕組みのように思ってしまう人が多いのではないだろうか。それはそれで、ある意味では正しい。しかしこの処理は、「重要度に重みをつけながら、系列全体の要素間関係を捉える」処理だとも言えて、「部分に向かう」と「全体に向かう」の両方向な役割を持っているのである。

「エンコーダ-デコーダ」モデルから説明しないと何も分からない

 上記2点については、Transformerが出てくる前の背景としてRNNベースの「エンコーダ-デコーダ」モデルが頭に入っていると誤解はしないのだが、その説明をすっ飛ばしてしまうと、たぶんあまり仕組みが理解できないと思う。だからもちろん、背景から説明されていることも多いのだが、そこがショートカットされている説明も散見される。


 もともと、RNNベースのエンコーダ-デコーダモデルにおいては、

  1. 1つ目の単語から順次処理して次に送っていく過程で、最初のほうの単語の情報が薄まってしまう(長期記憶が弱い)
  2. エンコーダからデコーダに渡される文脈ベクトルが固定次元なので表現力に乏しい
  3. 逐次処理なので計算に時間がかかる

などの問題があった。で、エンコーダとデコーダをアテンション機構*2でつなげることで大域的な文脈を失わずに処理するという工夫が提案され、1つ目と2つ目の問題に対処することが可能になった。そしてトランスフォーマーは、それをさらに押し進めてRNNブロック自体を排除してしまい、セルフアテンション等を用いて系列内の単語間関係を捉えるようになっている。また、RNNがいなくなったことで、計算も並列化できて速くなったとされる。

クロスアテンションのQに投入するのは「1つ前の単語」

 トランスフォーマーには2種類のアテンションが使われていて、一つが「self-attention」、もう一つは「encoder-decoder attention」と呼ばれている。後者は、source-target attentionとかcross-attentionとか呼ばれることもある。以下では、短くて呼びやすいので後者を「クロスアテンション」と呼ぶことにする。


 セルフアテンションは、入力又は出力の各系列の内部で単語間の関係を捉えるのに使われており、クロスアテンションは、入力の系列に含まれる単語と出力の系列に含まれる単語の関係を捉えるのに使われている。セルフアテンションはトランスフォーマーで一番重要とも言える機構であり、何をしているのかは分かりやすい。
 一方、後者のクロスアテンションの仕組みは、たとえば「This is a pen」→「これはペンです」という翻訳において「pen」と「ペン」が強く関係することを表現しているのだと説明される。それはまぁそうなのだが、このイメージを強く持ってしまうと、実際の処理をみていくときに混乱する。というのも、デコーダのクロスアテンション層でQ(クエリ)に投入されるのは、出力しようとしている単語の「1つ前」の単語の情報だからだ*3
 つまり、「ペン」を生成するときは、「は」を表すベクトル*4をクエリとして、「this」「is」「a」「pen」のそれぞれとの関連の強さを計算し、その強さで重み付けしながらバリューを足し合わせたものがアテンション情報として使われる*5。このアテンション情報が「ペン」の生成に使われることになるので、「ペン」との対応関係をみていると説明しても全く間違いではないのだろうが、「ペン」をクエリにするわけではないことに留意を促さないと、わけがわからなくなる。

何が入ってきて何が出ていくのか

 トランスフォーマーはけっこう部品が多いので、各部品の入口と出口で何が流れていくのかをいちいちハッキリさせておかないと、頭がこんがらがる。出入りするのがベクトルなのか行列なのかとか、次元は何なのかとかを毎度確認しないと、自分の理解が合ってるのかどうか不安になるのだが、そこが述べられていない説明を目にすることは多い。


 で、結局のところ、エンコーダにせよデコーダにせよ、トランスフォーマーブロックの中では常に「単語数×埋め込み次元数」の行列が流れていくと考えておけばよいと思う*6
 embedding層の前では各行が語彙数次元のone-hotベクトルになっていたり、最後の出力時には語彙数次元の確率分布ベクトルだったりするのだが、トランスフォーマーブロックの中では、「単語数×埋め込み次元数」の行列の各行にひたすら色々な変換が施されていく感じになっている(ただし各行が独立に処理されるという意味ではなく、セルフアテンションする時には行間=単語間で互いに影響し合うことになる)。
 逆にいうと、RNNの隠れ状態のような文脈ベクトル、つまり文意を1つのベクトルで表したものは生成されない*7
 あと、デコーダについては、仮に出力文のj番目の単語を出力しようとしている時には、

  • 入力されるのは、BOSトークンからj-1番目の単語までの意味ベクトル
  • 出力されるのは、1番目の単語からj番目の単語までの意味ベクトル

 というふうに「1個右にズレる」という点も、確認しておかないと頭がこんがらがるかもしれない。

KとVは同じもの(セルフアテンションではQも同じ)

 アテンション層の「Q」「K」「V」はクエリ、キー、バリューなのだが、データベース用語のイメージをあまり持ちすぎないほうがいいと思う。とくにセルフアテンション層ではQもKもVも同じ行列が投入されるので、「データベースにクエリを投げて、キーで探してバリューを取ってくる」みたいなイメージを持っていると、何をしてるのかわけが分からなくなる。
 クロスアテンション層では、KとVはどちらも「エンコーダの出力行列」であり、これはすなわち、各単語の意味ベクトル(を色々変換しまくったもの)を重ねたもので、全く同じものがKとVに投入される。ただし、アテンション層の入り口に線形層があって、ここでKとVには別々の変換が施されることになるので、結果的には別物になる。なお、Qはデコーダ側の入力単語ベクトルを重ねたものである。
 セルフアテンション層に至っては、Qも同じものが投入される。つまり、全く同じ行列を3つの入口に入れて、それぞれ「別の線形変換をした上で」アテンション処理が行われるという流れになっている。

逐次処理がなくなったわけではない

 推論時のデコーダ部分では、最初にBOSを入力すると1つめの単語が出力され、その1つ目の単語を入力すると2つ目の単語が出力されるというように、逐次的なループ処理になっている。並列的に処理する方法も考案されているのだが、ループさせたほうが精度は高いらしい。
 つまり、RNNを取り除いたと言っても、完全に逐次処理がなくなったわけではない(エンコーダと、学習時のデコーダは並列処理になっている)。

マルチヘッド処理では埋め込みベクトル自体が分割されてる

 トランスフォーマーのアテンションは「マルチヘッド・アテンション」と呼ばれ、系列を構成するベクトル間の関係を「複数の観点で」捉えるようになっている。
 イメージでいうと、「今日は大学で二次試験が行われていた」という文において、「二次試験」という語は文法的には「行う」との間に動詞-目的語という強い関係を持っているが、「二次試験」と「大学」の間にも、場面や文脈上の重要な関係がある。
 この「複数の観点」の学習は、ヘッドを複数用意するだけでいい感じにやってくれるんだろうか?と一瞬思ったのだが、もちろんそんなわけはなく、単語の埋め込みベクトル自体をチョン切って個々のヘッドに与えている。だから、ある意味もはや「別のデータ」を扱っているとも言える。
 たぶん、この「チョン切り」にうまく適応するように、埋め込み層の処理も学習されていくので、埋め込みベクトル自体が意味不明な数字の羅列ではなく、ある観点で重要な特徴次元が近い位置に集まるんだろう。面白い。

*1:論文では「RNNや畳込みを使わず、セルフアテンションに全面的に依拠してデータを処理する、初めての系列変換モデルである」と書かれている。

*2:ちなみに当初のアテンションは、Transformerのアテンションとは仕組みがけっこう違う。

*3:ちなみにRNNベースのエンコーダ・デコーダモデルでも、デコーダ出力の1つ前の隠れ状態と入力系列の関係がアテンションとして計算される

*4:セルフアテンションで自分より前の単語との関係が情報として織りこまれるので、もはや純粋に「は」を表すベクトルであるとは言えないかも知れないが。

*5:正確にいうと、デコーダにはから出力単語の1個前までのベクトルを重ねた行列が投入され、(BOSの次にくる)文頭単語から出力単語までのベクトルを重ねた行列が出てくるようになっており、ここではその最後の行だけに着目している。

*6:マルチヘッド化する時に一時的にちょん切られるが、すぐに結合される。

*7:細かいことを言うと、BERTで文章の区切り部分に入るトークン(CLS)のベクトルなんかは、文全体の意味を表しているとも言われるが。