PSPNetを実装するにあたり、この記事ではAuxLossモジュールを実装していきます。
実装については、次の実装を参考にしています。
ディレクトリ構成は以下のような構成になっています。
model
のpspnet.py
の中にAuxLossモジュールを実装していきます。
├─ data | └─ VOCdevkit | └─ VOC2012 | ├─ Annotations | ├─ ImageSets | ├─ JPEGImages | ├─ SegmentationClass | └─ SegmentationObject ├─ initmodel | └─ resnet50_v2.pth ├─ model | ├─ pspnet.py # ここにAuxLossモジュールを追記していく | └─ resnet.py └─ util └─ dataloader.py
次に今回使用するライブラリです。
import torch import torch.nn as nn import torch.nn.functional as F
最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。
1. AuxLossモジュールの役割
Feature Map モジュール、Pyramid Pooling モジュール、UpSampling モジュールの3つのモジュールでセグメンテーションモデルを実現することはできますが、AuxLoss モジュールを追加することで、損失関数の計算補助を行います。
Feature Map モジュールの途中からとってきた出力を、UpSampling モジュールと同様に入力画像のサイズと同じ大きさに変換します。
AuxLoss モジュールとUpSampling モジュールの二つの出力を、アノテーションデータと対応させて損失値を計算し、バックプロパゲーションを行うことでAuxLossモジュールが計算の補助を行うというわけです。
次にAuxLossモジュールの動きについて解説します。
以下の図を使って全体の動きを解説します。
まずはメインとなるのは、FatureMapモジュール、Pyramid Poolingモジュール、UpSamplingモジュールを通るルートになります。
これがメインとなるルートで最終的な出力とアノテーションデータとの損失(Loss 1)を計算します。
AuxLossモジュールは、FeatureMapモジュールのLayer3を途中で抜き出し、分岐させる形でAuxLossモジュールに渡します。
FeatureMapモジュールの途中の出力を渡したら畳み込みを行い入力画像と同じ大きさに画像を拡大させます。
このAuxLossモジュールからの出力とアノテーションデータとの損失(Loss 2)を計算します。
当然、FeatureMapモジュールの途中までの出力を使用するので分類精度としては低くなります。
ディープラーニングではネットワークが深くなればなるほど学習が困難になる傾向があり、ResNetもネットワークが深いので学習が難しいモデルになります。
そのためAuxLossモジュールではFeatureMapモジュールの途中の出力を抜き出して損失を計算することによって、ResNetの前半の層(Layer3より前)の学習を補助する役割を担います。
学習時にはメインルートの損失(Loss 1)で全体を最適化しつつ、AuxLossモジュールの損失(Loss 2)でResNetの前半の層(Layer3より前)を補助しながら学習を進めます。
以上が、AuxLossモジュールの役割と動きの解説になります。
2. PSPNetの実装
2.1. 実装上の注意点:学習時と推論時の動作を制御する
AuxLossモジュールはあくまでも学習時に補助的に使用するため、学習時には使用しますが推論時には使用しません。
つまり、PSPNetでは学習時または推論時で動作を変える必要があります。
やり方としては簡単で、nn.Module
を継承させたクラスを作成すると、変数としてtraining
を持っています。
これは学習時であればTrue
を返し、推論時であればFalse
を返します。
この変数を使用して、学習時にはAuxLossモジュールを使用し、推論時にはAuxLossモジュールを使用しないように動作を制御します。
こんな風に書くだけで制御できます。
# 学習時かどうかはクラスの変数として持っているので、self.trainingを参照するだけでOK if self.training is True: self.aux = nn.Sequential( nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(256, 21, kernel_size=21) )
ちなみに、PyTorchで作成したモデルを学習か推論かを指定するには次のように設定します。
model = PSPNet() # 学習用に設定 model.train() # 推論用に設定 model.eval()
2.2. 実装
PSPNetにAuxLossモジュールを追加していきます。
やることとしてはLayer3から出力を抜き出してAuxLossモジュールに渡しやることと、学習時のみAuxLossモジュールを使用するように制御することです。
加えて、モデル内のパラメータはわかりやすさを重視して今までべた書きしていましたが、今後の学習を見据えて分類するクラス数を指定できるように変更しておきます。
import sys sys.path.append('../') import torch import torch.nn as nn import torch.nn.functional as F from model.resnet import resnet50 class PSPNet(nn.Module): def __init__(self, n_classes=21): super(PSPNet, self).__init__() self.n_classes = n_classes # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する resnet = resnet50(pretrained=True) self.layer0 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool ) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 # layer3とlayer4の畳み込みフィルタのパラメータを変更する for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) # Pyramid Pooling モジュール self.ppm = PyramidPoolingModule() # UpSampling モジュール self.cls = nn.Sequential( nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(512, self.n_classes, kernel_size=1) ) # 学習時にのみAuxLossモジュールを使用するように設定 if self.training is True: self.aux = nn.Sequential( nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(256, self.n_classes, kernel_size=1) ) def forward(self, x, y=None): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) # AuxLossのためにlayer3から出力を抜き出しておく x_tmp = self.layer3(x) x = self.layer4(x_tmp) # Pyramid Pooling モジュール x = self.ppm(x) # UpSampling モジュール x = self.cls(x) # 入力画像と同じ大きさに変換する x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True) # 学習時にのみAuxLossモジュールを使用するように設定 if self.training is True: aux = self.aux(x_tmp) aux = F.interpolate(aux, size=(475, 475), mode='bilinear', align_corners=True) return x, aux return x
動作確認として、実際にモデルを動かしてみます。
# ダミーデータの作成 input = torch.rand(4, 3, 475, 475) model = PSPNet() # モデルを学習用に設定した場合 model.train() output, aux = model(input) print('学習時') print(output.shape) print(aux.shape) # モデルを推論用に設定した場合 model.eval() output = model(input) print('推論時') print(output.shape)
# Output 学習時 torch.Size([4, 21, 475, 475]) torch.Size([4, 21, 475, 475]) 推論時 torch.Size([4, 21, 475, 475])
想定通り、学習時には二つの出力が返りtorch.size([バッチサイズ, 21, 475, 475])
の形で出力されています。
一方で推論時には一つの出力が返りtorch.size([バッチサイズ, 21, 475, 475])
の形で出力されています。
これで、AuxLossモジュールの実装が完了です。
3. まとめ
今回の記事ではPSPNet内にAuxLossモジュールを追加しました。
これでPSPNetは完成になります。次の記事からは実装したモデルを学習させていきます。
最後に全体のコードを載せておきます。
4. 参考文献
論文
-
この論文の内容を紹介しています
-
-
https://github.com/hszhao/semseg
実装方法については、こちらのGithubのコードを参考にしています。
書籍
つくりながら学ぶ!PyTorchによる発展ディープラーニング
ファインチューニングの方法はこちらの書籍を参考にしています。
5. 全体コード
# model.pspnet.py ### ライブラリ import sys from torch.nn.modules import padding sys.path.append('../') import torch import torch.nn as nn import torch.nn.functional as F from model.resnet import resnet50 ### クラス定義 class PyramidPoolingModule(nn.Module): def __init__(self): super(PyramidPoolingModule, self).__init__() # 1×1スケール self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1) self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(512) self.relu1 = nn.ReLU(inplace=True) # 2×2スケール self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2) self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(512) self.relu2 = nn.ReLU(inplace=True) # 3×3スケール self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3) self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(512) self.relu3 = nn.ReLU(inplace=True) # 4×4スケール self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6) self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn6 = nn.BatchNorm2d(512) self.relu6 = nn.ReLU(inplace=True) def forward(self, x): # 1×1スケール out1 = self.avgpool1(x) out1 = self.conv1(out1) out1 = self.bn1(out1) out1 = self.relu1(out1) out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True) # 2×2スケール out2 = self.avgpool2(x) out2 = self.conv2(out2) out2 = self.bn2(out2) out2 = self.relu2(out2) out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True) # 3×3スケール out3 = self.avgpool3(x) out3 = self.conv3(out3) out3 = self.bn3(out3) out3 = self.relu3(out3) out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True) # 6×6スケール out6 = self.avgpool6(x) out6 = self.conv6(out6) out6 = self.bn6(out6) out6 = self.relu6(out6) out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True) # 元の入力と各スケールの特徴量を結合させる out = torch.cat([x, out1, out2, out3, out6], dim=1) return out class PSPNet(nn.Module): def __init__(self, n_classes=21): super(PSPNet, self).__init__() self.n_classes = n_classes # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する resnet = resnet50(pretrained=True) self.layer0 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool ) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 # layer3とlayer4の畳み込みフィルタのパラメータを変更する for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) # Pyramid Pooling モジュール self.ppm = PyramidPoolingModule() # UpSampling モジュール self.cls = nn.Sequential( nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(512, self.n_classes, kernel_size=1) ) # 学習時にのみAuxLossモジュールを使用するように設定 if self.training is True: self.aux = nn.Sequential( nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(256, self.n_classes, kernel_size=1) ) def forward(self, x, y=None): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) # AuxLossのためにlayer3から出力を抜き出しておく x_tmp = self.layer3(x) x = self.layer4(x_tmp) # Pyramid Pooling モジュール x = self.ppm(x) # UpSampling モジュール x = self.cls(x) # 入力画像と同じ大きさに変換する x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True) # 学習時にのみAuxLossモジュールを使用するように設定 if self.training is True: aux = self.aux(x_tmp) aux = F.interpolate(aux, size=(475, 475), mode='bilinear', align_corners=True) return x, aux return x # 動作確認 if __name__ == '__main__': input = torch.rand(4, 3, 475, 475) model = PSPNet() # モデルを学習用に設定した場合 model.train() output, aux = model(input) print('学習時') print(output.shape) print(aux.shape) # モデルを推論用に設定した場合 model.eval() output = model(input) print('推論時') print(output.shape)