更新:2024/10/18
PyTorchにおけるモデルの保存と読み込み!2つの方法

目次
1. PyTorchにおけるモデルの保存と読み込み
PyTorchは、深層学習モデルの開発や学習に使用されているフレームワークです。PyTorchで、学習したモデルをデータを保存する方法は大きく分けて二つあります。
一つ目は、モデル全体を保存する方法です。二つ目は、モデルのパラーメーターを保存する方法です。
1.1. モデルの全体の保存・ロード
1.1.1. モデル全体の保存
モデル全体を保存するには、torch.save を使用して、保存します。
torch.save(model, 'model.pth')
1.1.2. モデル全体のロード
保存されたモデルを読み込むには、torch.load を使用します。
model = torch.load('model.pth')
1.2. モデルのパラーメーターの保存・ロード
1.2.1. モデルのパラメーターの保存
モデルのパラメーターを保存するには、state_dictを使用します。
torch.save(model.state_dict(), 'model_state_dict.pth')
1.2.2. モデルのパラメーターのロード
パラメーターを読み込むには、まず同じ構造のモデルのインスタンスを作成し、その後、load_state_dictを使用して保存されたパラメーターを読み込みます。
model = MyModel()
model.load_state_dict(torch.load('model_state_dict.pth'))
PR