StatsBeginner: 初学者の統計学習ノート

統計学およびR、Pythonでのプログラミングの勉強の過程をメモっていくノート。たまにMacの話題。

DataParallelでの複数GPUの並列化が上手くいかない(PyTorch)

単なる作業経過のメモです。
AWSで、gクラスのインスタンスのvCPU数上限緩和を申請したら通りまして、複数GPUのインスタンスが使えるようになりました。
そこでGPU4枚のインスタンスを立てて、以下のような情報を参考に、先日構築したTransformer翻訳機にとりあえずDataParallelのほうを適用してみたのですが、なかなかうまくいかず苦労しています。


DataParallel — PyTorch 2.0 documentation
Distributed Data Parallel — PyTorch master documentation
Pytorch高速化 (1)Multi-GPU学習を試す - arutema47's blog
pytorch DistributedDataParallel 事始め - Qiita
DDPによる学習時間の高速化を確認してみる | FORXAI | コニカミノルタ


最初、一番苦労したのは、DataParallelがバッチを自動的に4分割する際に、マスクの形状が乱れることです。Transformerのマスクの制御は、本質的にそんなに難しいわけではないはずなのですが、いろいろミスが出てエラーに繋がります。
とりあえず、Transformer本体の関数に対して外でつくったマスクを与えるのではなく、なるべく内部で生成するようにしたら、エラーは出なくなりました。


ところが学習をさせてみると、どうも1エポック目の1バッチ目から、forward時にすべての出力(テンソルの中の値)がnanになってしまう問題が発生しており、とりあえずいま作業時間がないので中断して放置してます。
もしかしたら、最初からDistributedDataParallelのほうを試したほうがいいのかもしれない。そっちのほうが、適用するために書く部分は多いんですが、PyTorch公式でもそっちがおすすめされていて、速度も速いらしい。


【追記】DistributedDataParallelを使ったらうまく行きました。DDPのほうは、Jupyter Lab等の対話環境から使うのには向いておらず、学習用のコードをスクリプト化して直接実行することが前提になるので、もともと対話環境でコードを書いていた場合はそれを再構成するのがめんどくさい……と感じたのですが、実際やってみたら大して面倒ではなかったし、動作も安定していて何よりです。