PyTorchで実装するPSPNet モジュール実装 Pyramid Pooling モジュール偏 ⑦

PSPNetを実装するにあたり、この記事ではPyramid Pooling モジュールを実装していきます。

実装については、次の実装を参考にしています。

github.com

ディレクトリ構成は以下のような構成になっています。

modelpspnet.pyの中にPyramid Pooling モジュールを実装していきます。

├─ data
|   └─ VOCdevkit
|      └─ VOC2012
|           ├─ Annotations
|           ├─ ImageSets
|           ├─ JPEGImages
|           ├─ SegmentationClass
|           └─ SegmentationObject
├─ initmodel
|   └─ resnet50_v2.pth
├─ model
|   ├─ pspnet.py         # ここにPyramid Pooling モジュールを追記していく
|   └─ resnet.py         
└─ util
    └─ dataloader.py

次に今回使用するライブラリです。

import torch
import torch.nn as nn
import torch.nn.functional as F

最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。




1. Pyramid Pooling モジュールの役割

f:id:venoda:20210815180958p:plain

Feature Map モジュールで抽出した特徴量を、Poolingを行い様々なスケールに変換します。

論文中では1×1、2×2、3×3、6×6のスケールの層を作成しています。

異なるスケールの特徴量を作成することで、全体を考慮するまたは部分的に考慮するなどといったPSPNetの強みであるピクセルの周辺情報を考慮した特徴量を作成することができる。

次にPyramid Pooling モジュールの動きについて解説します。

以下の図を使って全体の動きを解説します。

f:id:venoda:20210815181011p:plain

Pyramid Pooling モジュールの入力は、Feature Mapモジュールからの出力になります。この時の入力のサイズは(2048 × 60 × 60)になります。

1段目では受け取った(2048×60×60)の入力に対して、それぞれ1×1、2×2、3×3、6×6のプーリング処理を行います。4つに分岐させることで異なるスケールの特徴量を作成します。

2段目ではさらに畳み込みを行い2048個のチャンネル数を512個に減らします。

3段目では画像拡大処理を行います。画像拡大前は高さと幅がそれぞれ1×1、2×2、3×3、6×6なので、すべて60×60に拡大します。

最後に、それぞれのスケールから出力された特徴量を元の特徴量も含めてすべて結合させます。

元の特徴量のサイズが(2048×60×60)で、それぞれのスケールが(512×60×60)の4つなので、最終的な出力は(4096×60×60)になります。

以上が、Pyramid Pooling モジュールの役割と動きの解説になります。




2. Pyramid Pooing モジュールの実装

分岐させると聞くと難しく聞こえるかもしれませんが、すごく簡単です。

プーリング処理はnn.AdaptiveAvgPool2dで行い、画像拡大処理はF.interpolateで行います。

nn.Moduleを継承させたPyramidPoolingModuleクラスを作成していきます。

import torch
import torch.nn as nn
import torch.nn.functional as F


class PyramidPoolingModule(nn.Module):
    def __init__(self):
        super(PyramidPoolingModule, self).__init__()
        
        # 1×1スケール
        self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU(inplace=True)

        # 2×2スケール
        self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2)
        self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(512)
        self.relu2 = nn.ReLU(inplace=True)
        
        # 3×3スケール
        self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3)
        self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(512)
        self.relu3 = nn.ReLU(inplace=True)
        
        # 4×4スケール
        self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6)
        self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU(inplace=True)
        
    def forward(self, x):
        # 1×1スケール
        out1 = self.avgpool1(x)
        out1 = self.conv1(out1)
        out1 = self.bn1(out1)
        out1 = self.relu1(out1)
        out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True)
        
        # 2×2スケール
        out2 = self.avgpool2(x)
        out2 = self.conv2(out2)
        out2 = self.bn2(out2)
        out2 = self.relu2(out2)
        out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True)
        
        # 3×3スケール
        out3 = self.avgpool3(x)
        out3 = self.conv3(out3)
        out3 = self.bn3(out3)
        out3 = self.relu3(out3)
        out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True)
        
        # 6×6スケール
        out6 = self.avgpool6(x)
        out6 = self.conv6(out6)
        out6 = self.bn6(out6)
        out6 = self.relu6(out6)
        out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True)
        
        # 元の入力と各スケールの特徴量を結合させる
        out = torch.cat([x, out1, out2, out3, out6], dim=1)
        
        return out

動作確認として、実際にPyramid Pooling モジュールを動かしてみます。

# ダミーデータ作成
input = torch.rand(4, 2048, 60, 60)

model = PyramidPoolingModule()
output = model(input)

print(output.shape)
# Output
torch.Size([4, 4096, 60, 60])

想定通り、torch.Size([バッチサイズ, 4096, 60, 60])の形で出力されています。




3. PSPNetの実装

PSPNetにPyramid Pooling モジュールを追加していきます。

ここではPyramid Pooling モジュールの追加のみですので、PSPNetの実装はまだ先に続きますのでご注意ください。

import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.resnet import resnet50


class PSPNet(nn.Module):
    
    def __init__(self):
        super(PSPNet, self).__init__()

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        
        # Pyramid Pooling モジュール
        self.ppm = PyramidPoolingModule()
            
    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling モジュール
        x = self.ppm(x)
        
        return x

動作確認として、実際にモデルを動かしてみます。

# ダミーデータの作成
input = torch.rand(4, 3, 475, 475)

model = PSPNet()
model.eval()
output = model(input)

print(output.shape)
# Output
torch.Size([4, 2048, 60, 60])

想定通り、torch.size([バッチサイズ, 2048, 60, 60])の形で出力されています。

これで、Pyramid Pooling モジュールの実装が完了です。




4. まとめ

今回の記事ではPyramid Pooling モジュールの実装と、PSPNet内にPyramid Pooling モジュールを追加しました。

これだけではPSPNetは完成でないなので、さらに実装を続けていきます。

最後に全体のコードを載せておきます。


5. 参考文献




6. 全体コード

# model.pspnet.py
### ライブラリ
import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.resnet import resnet50


### クラス定義
class PyramidPoolingModule(nn.Module):
    def __init__(self):
        super(PyramidPoolingModule, self).__init__()
        
        # 1×1スケール
        self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU(inplace=True)

        # 2×2スケール
        self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2)
        self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(512)
        self.relu2 = nn.ReLU(inplace=True)
        
        # 3×3スケール
        self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3)
        self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(512)
        self.relu3 = nn.ReLU(inplace=True)
        
        # 4×4スケール
        self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6)
        self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU(inplace=True)
        
    def forward(self, x):
        # 1×1スケール
        out1 = self.avgpool1(x)
        out1 = self.conv1(out1)
        out1 = self.bn1(out1)
        out1 = self.relu1(out1)
        out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True)
        
        # 2×2スケール
        out2 = self.avgpool2(x)
        out2 = self.conv2(out2)
        out2 = self.bn2(out2)
        out2 = self.relu2(out2)
        out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True)
        
        # 3×3スケール
        out3 = self.avgpool3(x)
        out3 = self.conv3(out3)
        out3 = self.bn3(out3)
        out3 = self.relu3(out3)
        out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True)
        
        # 6×6スケール
        out6 = self.avgpool6(x)
        out6 = self.conv6(out6)
        out6 = self.bn6(out6)
        out6 = self.relu6(out6)
        out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True)
        
        # 元の入力と各スケールの特徴量を結合させる
        out = torch.cat([x, out1, out2, out3, out6], dim=1)
        
        return out


class PSPNet(nn.Module):
    
    def __init__(self):
        super(PSPNet, self).__init__()

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        
        # Pyramid Pooling モジュール
        self.ppm = PyramidPoolingModule()
            
    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling モジュール
        x = self.ppm(x)
        
        return x


# 動作確認
if __name__ == '__main__':
    input = torch.rand(4, 3, 475, 475)
    model = PSPNet()
    model.eval()
    output = model(input)
    print(output.shape)