今までの記事でPSPNetが実装できました、この記事からは実際にモデルを学習させていきます。
この記事ではファインチューニングを実施するにあたり、どのような方針で進めていくかの説明とファインチューニングを実施するための前準備を行います。
ディレクトリ構成としては、以下のような構成になっています。
parameters/train_model
の中に学習済みモデルを保存します。
notebook/01_ファインチューニング_準備偏.ipynb
に準備のためのスクリプトを実装していきます。
├─ data | └─ VOCdevkit | └─ VOC2012 | ├─ Annotations | ├─ ImageSets | ├─ JPEGImages | ├─ SegmentationClass | └─ SegmentationObject ├─ initmodel | └─ resnet50_v2.pth ├─ model | ├─ pspnet.py | └─ resnet.py ├─ notebook | └─ 01_ファインチューニング_準備偏.ipynb # 準備用のスクリプトを実装する ├─ parameters | └─ trained_model | ├─ train_epoch_100.pth # ダウンロードした学習済みモデルはここに保存する | └─ pspnet_base.pth # 修正した学習済みモデルはここに保存する └─ util └─ dataloader.py
1. ファインチューニングの実装方針
PSPNetを学習させるにあたり、ゼロからモデルの学習はさせずに、学習済みモデルを使用してファインチューニングを実施していく方針としています。
実装の参考にしているGithubでは、PASCAL VOC 2012、ADE20K、Cityscapesのデータで学習された3つのモデルが公開されています。
ファインチューニングで使用する学習済みモデルは、こちらで公開されているモデルを使用します。
ここで注意する点として使用する学習済みモデルをどれにするかです。
今回、学習に使用するデータはPASCAL VOC 2012のデータになります。
なので、使用する学習済みモデルはADE20Kのモデルを使用します。
(PASCAL VOC 2012で学習させたモデルを使用して、PASCAL VOC 2012のデータを学習させてはファインチューニングにはなりません。)
最後にファインチューニングの実装の流れとして次のような流れになります。
- ADE20Kの学習済みモデルをダウンロード
- ADE20Kのモデルを読み込む
- 読み込んだADE20Kのモデルの出力層をPASCAL VOC 2012のデータに対応するように修正
- PASCAL VOC 2012のデータでファインチューニング
以上がファインチューニングの実装方針になります。
2. 実行環境について
実行環境としてGoogle Colaboratoryを想定しています。
GPUが無料で使用できるとはいえ、メモリ等の制約条件が存在します。
バッチサイズを小さくして、メモリエラーに備える工夫をしています。
また、Google Colaboratoryは使用量の上限が存在するため、上限を超えると切断されてしまいます。
(他にも90分ルールや12時間ルールなど存在します。)
そのため1エポックごとにモデルを保存するようにしています。
切断されたら学習中の保存されたモデルを読み込んでこまめに再実行できるようにするためです。
3. モデルの準備
実際にファインチューニングを実施する前に、ADE20Kデータセットのダウンロードとモデルを読み込むための準備を行います。
3.1. 学習済みモデルのダウンロード
学習済みのPSPNetのパラメータをダウンロードします。
ダウンロード先は以下のURLです。train_epoch_100.pth
をダウンロードしてください。
ダウンロードしたら、parameters/trained_model
の下に格納してください。
├─ data | └─ VOCdevkit | └─ VOC2012 | ├─ Annotations | ├─ ImageSets | ├─ JPEGImages | ├─ SegmentationClass | └─ SegmentationObject ├─ initmodel | └─ resnet50_v2.pth ├─ model | ├─ pspnet.py | └─ resnet.py ├─ parameters | ├─ checkpoint | └─ trained_model | └─ train_epoch_100.pth # ← ダウンロードした学習済みモデルはここに保存する └─ util └─ dataloader.py
3.2. パラメータ内のキー名の変更
PyTorchでモデルのパラメータは、OrderedDict
形式で層の名前をKeyにしてパラメータが格納されています。
こんな感じで格納されています。
OrderedDict([('layer0.0.weight', tensor([[[[-3.6252e-02, -9.5168e-02, 1.9156e-01], [-2.2901e-02, -1.0108e-01, 2.5332e-01], [ 1.0695e-01, -1.7415e-01, -1.8174e-01]], ・・・・・ [[ 4.6924e-08, 2.6470e-08, 4.8394e-08], [ 4.2869e-08, -4.7989e-08, -3.0590e-08], [ 7.6188e-08, 2.5578e-08, 3.9737e-08]]]])), ('layer0.1.weight', tensor([3.1897e-01, 3.6042e-01, 1.9768e-06, 2.5844e-06, 6.3752e-05, 1.2454e-01, 5.1326e-01, 2.6607e-08, 2.9555e-01, 2.4985e-01, 3.1859e-01, 5.1083e-01, ・・・・・
PyTorchで学習済みモデルを読み込む際には、OrderedDict
内のKey名(ここでは層の名前)と数が一致している必要があります。
実際に今回実装したPSPNetとダウンロードしたモデル内のKey名を確認してみます。
(ちなみに、モデルのパラメータはstate_dict()
で確認することができます。)
import sys sys.path.append('../') import torch from model.pspnet import PSPNet # 実装したPSPNetのKey名を確認する model = PSPNet(n_classes=150) print('実装したPSPNet') print('Keyの数 : ', len(model.state_dict().keys())) print('Key名 : ') print(model.state_dict().keys()) # ダウンロードしたPSPNetのKey名を確認する param_path = '../parameters/trained_model/train_epoch_100.pth' param_ade = torch.load(param_path, map_location=torch.device('cpu')) print('ダウンロードしたPSPNet') print('Keyの数 : ', len(param_ade['state_dict'].keys())) print('Key名') print(param_ade['state_dict'].keys())
# Output 実装したPSPNet Keyの数 : 370 Key名 : odict_keys(['layer0.0.weight', 'layer0.1.weight', 'layer0.1.bias', 'layer0.1.running_mean', ・・・・・ 'aux.1.running_mean', 'aux.1.running_var', 'aux.1.num_batches_tracked', 'aux.4.weight', 'aux.4.bias']) ダウンロードしたPSPNet Keyの数 : 370 Key名 : odict_keys(['module.layer0.0.weight', 'module.layer0.1.weight', 'module.layer0.1.bias', 'module.layer0.1.running_mean', ・・・・・ 'module.aux.1.running_mean', 'module.aux.1.running_var', 'module.aux.1.num_batches_tracked', 'module.aux.4.weight', 'module.aux.4.bias'])
出力してみると実装したPSPNetとダウンロードしたPSPNetでパラメータの数は同じですが、Key名が異なります。
ダウンロードしたPSPNetのパラメータのKey名を、実装したPSPNetのパラメータのKey名に変更する必要があります。
(変更しないとエラーが出力され、ダウンロードした学習済みのパラメータを読み込むことができません。)
変更方法はKey名を一つずつ変更し、新しいOrderedDict
を作成していきます。
import sys sys.path.append('../') from collections import OrderedDict import torch from model.pspnet import PSPNet # ダウンロードしたPSPNetのパラメータを読み込む param_path = '../parameters/trained_model/train_epoch_100.pth' param_ade = torch.load(param_path, map_location=torch.device('cpu')) # Key名を変更する param_list = [] ppm_name_list = [ 'conv1', 'bn1', 'bn1', 'bn1', 'bn1', 'bn1', 'conv2', 'bn2', 'bn2', 'bn2', 'bn2', 'bn2', 'conv3', 'bn3', 'bn3', 'bn3', 'bn3', 'bn3', 'conv6', 'bn6', 'bn6', 'bn6', 'bn6', 'bn6', ] for before_key in param_ade['state_dict'].keys(): after_key = before_key.replace('module.', '').replace('features.', '') if 'ppm' in after_key: after_key = after_key.replace('ppm.', 'ppm.{}.'.format(ppm_name_list.pop(0))) after_key = after_key.replace('.0', '').replace('.1', '').replace('.2', '').replace('.3', '') param_list.append((after_key, param_ade['state_dict'][before_key])) param_ade_rename = OrderedDict(param_list)
これでパラメータのKey名を変更することができました。
実装したPSPNetに読み込めるかを確認します。
import sys sys.path.append('../') import torch from model.pspnet import PSPNet # 実装したPSPNet model = PSPNet(n_classes=150) # Key名を変更したパラメータを読み込む model.load_state_dict(param_ade_rename)
# Output <All keys matched successfully>
<All keys matched successfully>
と出力されれば、パラメータが正常に読み込むことができています。
ファインチューニング時に、Key名を変更したモデルをベースに学習を進めていくためモデルのパラメータを保存します。
次の処理でモデルのパラメータを保存することができます。
torch.save(model.state_dict(), '../parameters/trained_model/pspnet_base.pth')
以上がモデルの前準備になります。
4. まとめ
今回の記事ではファインチューニング実施前に、ファインチューニングの実装方針とモデルの前準備を行いました。
次の記事では、実際にファインチューニングを実施していきます。
最後に全体のコードを載せておきます。
5. 参考文献
論文
-
この論文の内容を紹介しています
-
-
https://github.com/hszhao/semseg
実装方法については、こちらのGithubのコードを参考にしています。
書籍
つくりながら学ぶ!PyTorchによる発展ディープラーニング
ファインチューニングの方法はこちらの書籍を参考にしています。
6. 全体コード
# notebool/01_ファインチューニング_準備偏.ipynb ### ライブラリ import sys sys.path.append('../') import torch from model.pspnet import PSPNet from collections import OrderedDict ### 中身の確認 # 実装したPSPNetのKey名を確認する model = PSPNet(n_classes=150) print('実装したPSPNet') print('Keyの数 : ', len(model.state_dict().keys())) print('Key名 : ') print(model.state_dict().keys()) # ダウンロードしたPSPNetのKey名を確認する param_path = '../parameters/trained_model/train_epoch_100.pth' param_ade = torch.load(param_path, map_location=torch.device('cpu')) print('ダウンロードしたPSPNet') print('Keyの数 : ', len(param_ade['state_dict'].keys())) print('Key名') print(param_ade['state_dict'].keys()) # Key名が一致しないとエラーが出力される # model.load_state_dict(param_ade) ### Key名の変更 # パラメータ内のキーを変更 param_list = [] ppm_name_list = [ 'conv1', 'bn1', 'bn1', 'bn1', 'bn1', 'bn1', 'conv2', 'bn2', 'bn2', 'bn2', 'bn2', 'bn2', 'conv3', 'bn3', 'bn3', 'bn3', 'bn3', 'bn3', 'conv6', 'bn6', 'bn6', 'bn6', 'bn6', 'bn6', ] for before_key in param_ade['state_dict'].keys(): after_key = before_key.replace('module.', '').replace('features.', '') if 'ppm' in after_key: after_key = after_key.replace('ppm.', 'ppm.{}.'.format(ppm_name_list.pop(0))) after_key = after_key.replace('.0', '').replace('.1', '').replace('.2', '').replace('.3', '') param_list.append((after_key, param_ade['state_dict'][before_key])) param_ade_rename = OrderedDict(param_list) ### 変更結果の確認 model = PSPNet(n_classes=150) model.load_state_dict(param_ade_rename) ### モデルの保存 torch.save(model.state_dict(), '../parameters/trained_model/pspnet_base.pth')