PyTorchのモデル保存と読み込み

1. 概要

今回はPyTorchでモデルを学習させた後に、別のプログラムで学習させたモデルを使えるように

学習済みモデルの保存と読み込み方法を紹介します。

2. モデルの保存と読み込み

まずはモデルの保存です。

公式ドキュメントによると、慣例によりPyTorchファイルは「.pt」「.pth」拡張子で保存するみたいです。

import torch

# 保存
save_path = './weights.pth'
torch.save(net.state_dict(), save_path)

モデルをtorch.saveで直接保存することもできますが、

公式ドキュメントによると互換性の理由からモデルを直接保存するよりも、

state_dict()で辞書化して保存することが推奨されています。

pytorch.org

モデルは次のように読み込みます。

# 読み込み
load_path = './weights.pth'
load_weights = torch.load(load_path)
net.load_state_dict(load_weights)

これでモデルの保存と読み込みができます。



3. GPUで学習してCPUで読み込み

モデルを読み込む際に注意点として、GPU上で保存されたファイルをCPU上で読み込む場合は、

map_locationを使用する必要があります。

次のように読み込みます。

# GPU上で保存された重みをCPU上で読み込む場合
load_path = './weights.pth'
load_weights = torch.load(load_path, map_location={'cuda:0': 'cpu'})
net.load_state_dict(load_weights)