学習は問題ないのにテスト時にtorch.load・load_state_dictでCUDA out of memoryが出る場合

問題

  • Pytorchで、学習終了後にテストに移る際に最良のモデルをロードしようとするとCUDA out of memoryが発生する
  • VRAMに移動したモデルに対してチェックポイントからロードしようとするとCUDA out of memoryが発生する

原因

torch.loadが新しい重み情報をロードする際に一度VRAMに配置し、load_state_dictでmodelに適用されているため。

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
model.to("cuda")
# Train roop (Save best model)
model.load_state_dict(torch.load("model.pth")) # # Load best model

学習の際に最良のモデルを保存しておいて、上記のようにload_state_dictで最良のモデルをロードするようにする場合があると思います。この際に、どうやらtorch.loadが気を利かせてVRAM(saveした際のデバイス上)にロードしてくれているみたいですが、torch.loadload_state_dictなので、一時的に新しい重みと古い重みの両方がVRAMに乗ることになってしまいます。
しかし、LLMなどの巨大なモデルを利用していると、「学習後のmodelにある元々の重み情報」+「ロードしたい最良の重み情報」の合計がVRAMのサイズを超えてしまう場合もよくあると思います。

解決策

CPUでチェックポイントの重みをロードして、GPUに配置されたモデルに適用する。

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
model.to("cuda")
# Train roop (Save best model)
model.load_state_dict(torch.load("model.pth", map_location="cpu")) # Load best model