PSPNetを実装するにあたり、この記事ではPyramid Pooling モジュールを実装していきます。
実装については、次の実装を参考にしています。
ディレクトリ構成は以下のような構成になっています。
model
のpspnet.py
の中にPyramid Pooling モジュールを実装していきます。
├─ data | └─ VOCdevkit | └─ VOC2012 | ├─ Annotations | ├─ ImageSets | ├─ JPEGImages | ├─ SegmentationClass | └─ SegmentationObject ├─ initmodel | └─ resnet50_v2.pth ├─ model | ├─ pspnet.py # ここにPyramid Pooling モジュールを追記していく | └─ resnet.py └─ util └─ dataloader.py
次に今回使用するライブラリです。
import torch import torch.nn as nn import torch.nn.functional as F
最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。
1. Pyramid Pooling モジュールの役割
Feature Map モジュールで抽出した特徴量を、Poolingを行い様々なスケールに変換します。
論文中では1×1、2×2、3×3、6×6のスケールの層を作成しています。
異なるスケールの特徴量を作成することで、全体を考慮するまたは部分的に考慮するなどといったPSPNetの強みであるピクセルの周辺情報を考慮した特徴量を作成することができる。
次にPyramid Pooling モジュールの動きについて解説します。
以下の図を使って全体の動きを解説します。
Pyramid Pooling モジュールの入力は、Feature Mapモジュールからの出力になります。この時の入力のサイズは(2048 × 60 × 60)になります。
1段目では受け取った(2048×60×60)の入力に対して、それぞれ1×1、2×2、3×3、6×6のプーリング処理を行います。4つに分岐させることで異なるスケールの特徴量を作成します。
2段目ではさらに畳み込みを行い2048個のチャンネル数を512個に減らします。
3段目では画像拡大処理を行います。画像拡大前は高さと幅がそれぞれ1×1、2×2、3×3、6×6なので、すべて60×60に拡大します。
最後に、それぞれのスケールから出力された特徴量を元の特徴量も含めてすべて結合させます。
元の特徴量のサイズが(2048×60×60)で、それぞれのスケールが(512×60×60)の4つなので、最終的な出力は(4096×60×60)になります。
以上が、Pyramid Pooling モジュールの役割と動きの解説になります。
2. Pyramid Pooing モジュールの実装
分岐させると聞くと難しく聞こえるかもしれませんが、すごく簡単です。
プーリング処理はnn.AdaptiveAvgPool2d
で行い、画像拡大処理はF.interpolate
で行います。
nn.Module
を継承させたPyramidPoolingModuleクラスを作成していきます。
import torch import torch.nn as nn import torch.nn.functional as F class PyramidPoolingModule(nn.Module): def __init__(self): super(PyramidPoolingModule, self).__init__() # 1×1スケール self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1) self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(512) self.relu1 = nn.ReLU(inplace=True) # 2×2スケール self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2) self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(512) self.relu2 = nn.ReLU(inplace=True) # 3×3スケール self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3) self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(512) self.relu3 = nn.ReLU(inplace=True) # 4×4スケール self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6) self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn6 = nn.BatchNorm2d(512) self.relu6 = nn.ReLU(inplace=True) def forward(self, x): # 1×1スケール out1 = self.avgpool1(x) out1 = self.conv1(out1) out1 = self.bn1(out1) out1 = self.relu1(out1) out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True) # 2×2スケール out2 = self.avgpool2(x) out2 = self.conv2(out2) out2 = self.bn2(out2) out2 = self.relu2(out2) out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True) # 3×3スケール out3 = self.avgpool3(x) out3 = self.conv3(out3) out3 = self.bn3(out3) out3 = self.relu3(out3) out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True) # 6×6スケール out6 = self.avgpool6(x) out6 = self.conv6(out6) out6 = self.bn6(out6) out6 = self.relu6(out6) out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True) # 元の入力と各スケールの特徴量を結合させる out = torch.cat([x, out1, out2, out3, out6], dim=1) return out
動作確認として、実際にPyramid Pooling モジュールを動かしてみます。
# ダミーデータ作成 input = torch.rand(4, 2048, 60, 60) model = PyramidPoolingModule() output = model(input) print(output.shape)
# Output torch.Size([4, 4096, 60, 60])
想定通り、torch.Size([バッチサイズ, 4096, 60, 60])
の形で出力されています。
3. PSPNetの実装
PSPNetにPyramid Pooling モジュールを追加していきます。
ここではPyramid Pooling モジュールの追加のみですので、PSPNetの実装はまだ先に続きますのでご注意ください。
import sys sys.path.append('../') import torch import torch.nn as nn import torch.nn.functional as F from model.resnet import resnet50 class PSPNet(nn.Module): def __init__(self): super(PSPNet, self).__init__() # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する resnet = resnet50(pretrained=True) self.layer0 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool ) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 # layer3とlayer4の畳み込みフィルタのパラメータを変更する for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) # Pyramid Pooling モジュール self.ppm = PyramidPoolingModule() def forward(self, x): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) # AuxLossのためにlayer3から出力を抜き出しておく x_tmp = self.layer3(x) x = self.layer4(x_tmp) # Pyramid Pooling モジュール x = self.ppm(x) return x
動作確認として、実際にモデルを動かしてみます。
# ダミーデータの作成 input = torch.rand(4, 3, 475, 475) model = PSPNet() model.eval() output = model(input) print(output.shape)
# Output torch.Size([4, 2048, 60, 60])
想定通り、torch.size([バッチサイズ, 2048, 60, 60])
の形で出力されています。
これで、Pyramid Pooling モジュールの実装が完了です。
4. まとめ
今回の記事ではPyramid Pooling モジュールの実装と、PSPNet内にPyramid Pooling モジュールを追加しました。
これだけではPSPNetは完成でないなので、さらに実装を続けていきます。
最後に全体のコードを載せておきます。
5. 参考文献
論文
-
この論文の内容を紹介しています
-
-
https://github.com/hszhao/semseg
実装方法については、こちらのGithubのコードを参考にしています。
書籍
つくりながら学ぶ!PyTorchによる発展ディープラーニング
ファインチューニングの方法はこちらの書籍を参考にしています。
6. 全体コード
# model.pspnet.py ### ライブラリ import sys sys.path.append('../') import torch import torch.nn as nn import torch.nn.functional as F from model.resnet import resnet50 ### クラス定義 class PyramidPoolingModule(nn.Module): def __init__(self): super(PyramidPoolingModule, self).__init__() # 1×1スケール self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1) self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(512) self.relu1 = nn.ReLU(inplace=True) # 2×2スケール self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2) self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(512) self.relu2 = nn.ReLU(inplace=True) # 3×3スケール self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3) self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(512) self.relu3 = nn.ReLU(inplace=True) # 4×4スケール self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6) self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) self.bn6 = nn.BatchNorm2d(512) self.relu6 = nn.ReLU(inplace=True) def forward(self, x): # 1×1スケール out1 = self.avgpool1(x) out1 = self.conv1(out1) out1 = self.bn1(out1) out1 = self.relu1(out1) out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True) # 2×2スケール out2 = self.avgpool2(x) out2 = self.conv2(out2) out2 = self.bn2(out2) out2 = self.relu2(out2) out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True) # 3×3スケール out3 = self.avgpool3(x) out3 = self.conv3(out3) out3 = self.bn3(out3) out3 = self.relu3(out3) out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True) # 6×6スケール out6 = self.avgpool6(x) out6 = self.conv6(out6) out6 = self.bn6(out6) out6 = self.relu6(out6) out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True) # 元の入力と各スケールの特徴量を結合させる out = torch.cat([x, out1, out2, out3, out6], dim=1) return out class PSPNet(nn.Module): def __init__(self): super(PSPNet, self).__init__() # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する resnet = resnet50(pretrained=True) self.layer0 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool ) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 # layer3とlayer4の畳み込みフィルタのパラメータを変更する for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) # Pyramid Pooling モジュール self.ppm = PyramidPoolingModule() def forward(self, x): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) # AuxLossのためにlayer3から出力を抜き出しておく x_tmp = self.layer3(x) x = self.layer4(x_tmp) # Pyramid Pooling モジュール x = self.ppm(x) return x # 動作確認 if __name__ == '__main__': input = torch.rand(4, 3, 475, 475) model = PSPNet() model.eval() output = model(input) print(output.shape)