PSPNetを実装するにあたり、この記事ではUpSampling モジュールを実装していきます。
実装については、次の実装を参考にしています。
ディレクトリ構成は以下のような構成になっています。
model
のpspnet.py
の中にUpSampling モジュールを実装していきます。
├─ data | └─ VOCdevkit | └─ VOC2012 | ├─ Annotations | ├─ ImageSets | ├─ JPEGImages | ├─ SegmentationClass | └─ SegmentationObject ├─ initmodel | └─ resnet50_v2.pth ├─ model | ├─ pspnet.py # ここにUpSampling モジュールを追記していく | └─ resnet.py └─ util └─ dataloader.py
次に今回使用するライブラリです。
import torch import torch.nn as nn import torch.nn.functional as F
最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。
1. UpSamplingモジュールの役割
Pyramid Pooling モジュールの出力をさらに畳み込みを行い入力画像と同じ大きさに変換します。
Pyramid Pooling モジュールの出力は (4096×60×60)なので、UpSamplingモジュールでは( 21(クラス数)×475(高さ)×475(幅) )に変換するということです。
次にUpSamplingモジュールの動きについて解説します。
以下の図を使って全体の動きを解説します。
処理の流れとして、Pyramid Pooling モジュールからの出力を入力にします。この時の入力のサイズは(4096×60×60)になります。
これを畳み込みフィルタを通して、(21×60×60)に変換します。
正解データとの誤差を計算するために入力画像と同じ大きさに拡大します。
この一連の処理をすることで、最終的な出力は(クラス数×高さ×幅)になります。
以上が、UpSamplingモジュールの役割と動きの解説になります。
2. PSPNetの実装
PSPNetにUpSamplingモジュールを追加していきます。
やることとしては畳み込みフィルとと画像拡大処理を追加するだけです。
ここではUpSamplingモジュールの追加のみですので、PSPNetの実装はまだ先に続きますのでご注意ください。
import sys sys.path.append('../') import torch import torch.nn as nn import torch.nn.functional as F from model.resnet import resnet50 class PSPNet(nn.Module): def __init__(self): super(PSPNet, self).__init__() # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する resnet = resnet50(pretrained=True) self.layer0 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool ) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 # layer3とlayer4の畳み込みフィルタのパラメータを変更する for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) # Pyramid Pooling モジュール self.ppm = PyramidPoolingModule() # UpSampling モジュール self.cls = nn.Sequential( nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(512, 21, kernel_size=1) ) def forward(self, x): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) # AuxLossのためにlayer3から出力を抜き出しておく x_tmp = self.layer3(x) x = self.layer4(x_tmp) # Pyramid Pooling モジュール x = self.ppm(x) # UpSampling モジュール x = self.cls(x) # 入力画像と同じ大きさに変換する x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True) return x
動作確認として、実際にモデルを動かしてみます。
# ダミーデータの作成 input = torch.rand(4, 3, 475, 475) model = PSPNet() model.eval() output = model(input) print(output.shape)
# Output torch.Size([4, 21, 475, 475])
想定通り、torch.size([バッチサイズ, 21, 475, 475])
の形で出力されています。
これで、UpSamplingモジュールの実装が完了です。
3. まとめ
今回の記事ではPSPNet内にUpSamplingモジュールを追加しました。
これだけではPSPNetは完成でないなので、さらに実装を続けていきます。
最後に全体のコードを載せておきます。
4. 参考
論文
-
この論文の内容を紹介しています
-
-
https://github.com/hszhao/semseg
実装方法については、こちらのGithubのコードを参考にしています。
書籍
つくりながら学ぶ!PyTorchによる発展ディープラーニング
ファインチューニングの方法はこちらの書籍を参考にしています。
5. 全体コード
# model.pspnet.py ### ライブラリ import sys sys.path.append('../') import torch import torch.nn as nn import torch.nn.functional as F from model.resnet import resnet50 ### クラス定義 class PyramidPoolingModule(nn.Module): def __init__(self): super(PyramidPoolingModule, self).__init__() # 1×1スケール self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1) self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(512) self.relu1 = nn.ReLU(inplace=True) # 2×2スケール self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2) self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(512) self.relu2 = nn.ReLU(inplace=True) # 3×3スケール self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3) self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(512) self.relu3 = nn.ReLU(inplace=True) # 4×4スケール self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6) self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn6 = nn.BatchNorm2d(512) self.relu6 = nn.ReLU(inplace=True) def forward(self, x): # 1×1スケール out1 = self.avgpool1(x) out1 = self.conv1(out1) out1 = self.bn1(out1) out1 = self.relu1(out1) out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True) # 2×2スケール out2 = self.avgpool2(x) out2 = self.conv2(out2) out2 = self.bn2(out2) out2 = self.relu2(out2) out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True) # 3×3スケール out3 = self.avgpool3(x) out3 = self.conv3(out3) out3 = self.bn3(out3) out3 = self.relu3(out3) out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True) # 6×6スケール out6 = self.avgpool6(x) out6 = self.conv6(out6) out6 = self.bn6(out6) out6 = self.relu6(out6) out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True) # 元の入力と各スケールの特徴量を結合させる out = torch.cat([x, out1, out2, out3, out6], dim=1) return out class PSPNet(nn.Module): def __init__(self): super(PSPNet, self).__init__() # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する resnet = resnet50(pretrained=True) self.layer0 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool ) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 # layer3とlayer4の畳み込みフィルタのパラメータを変更する for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) # Pyramid Pooling モジュール self.ppm = PyramidPoolingModule() # UpSampling モジュール self.cls = nn.Sequential( nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Dropout(0.1), nn.Conv2d(512, 21, kernel_size=1) ) def forward(self, x): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) # AuxLossのためにlayer3から出力を抜き出しておく x_tmp = self.layer3(x) x = self.layer4(x_tmp) # Pyramid Pooling モジュール x = self.ppm(x) # UpSampling モジュール x = self.cls(x) # 入力画像と同じ大きさに変換する x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True) return x # 動作確認 if __name__ == '__main__': input = torch.rand(4, 3, 475, 475) model = PSPNet() model.eval() output = model(input) print(output.shape)