StatsBeginner: 初学者の統計学習ノート

統計学およびR、Pythonでのプログラミングの勉強の過程をメモっていくノート。たまにMacの話題。

PyTorch初心者のメモ

以下は、PyTorchでのニューラルネット構築について、学んだ基礎的事項のメモです。

  • nn.ModuleというのはTransformerを含めたニューラルネットワークの部品を意味していて、nn.Moduleの__call__メソッドはforwardメソッドを呼ぶようになっているので、nn.Moduleを継承して作られたインスタンスは、関数の形で呼び出されると自動的にテンソルをforwardして結果を出力する。
  • PyTorchの計算グラフは、1つのバッチを処理する過程で生成される一時的なものである(動的グラフ)。ネットワーク上を、あるバッチ化されたテンソル(torch.Tensorクラスのインスタンス)が流れていく際、様々な操作が順次加わってどんどん中間テンソルが生成されていくが、それらが全てgrad_fnという属性を持っている。このgrad_fnには各テンソルに対する操作の情報が保存されていて、その総体が計算グラフということになる(個々のテンソルはそのグラフのノード)。
  • backward(バックプロパゲーション)は、torch.Tensorのメソッドになっている。これはどう動くかというと、ネットワーク内をforward - forward - ...と辿って形を変えてきたデータテンソル(ここでは出力)と、答えのラベルを、torch.nn.CrossEntropyLoss等で作られた誤差関数に突っ込むことで、まず誤差テンソルが生成される。この誤差テンソルのbackwardメソッドを呼び出すと、誤差テンソルのgrad_fnが起点となって、計算グラフ内の各中間生成テンソルのgrad_fnに保存されている情報に従って、前のテンソル、前のテンソル...と辿っていき、計算グラフ内で勾配が計算される。
  • 勾配は、計算グラフ上で(というか計算グラフを利用して)計算されて、パラメータテンソルの.grad属性に格納される。ここで、データもテンソルだが、パラメータもテンソルである点に注意。データテンソルの.grad_fnはデータに対する操作の情報を、パラメータテンソルの.gradha勾配の情報を持っている。
  • 次にoptimizer.step()を行うと、各パラメータテンソルに格納された勾配に基づいて、パラメータの値を更新する。