PyTorchで実装するPSPNet モジュール実装 UpSampling モジュール偏 ⑧

PSPNetを実装するにあたり、この記事ではUpSampling モジュールを実装していきます。

実装については、次の実装を参考にしています。

github.com

ディレクトリ構成は以下のような構成になっています。

modelpspnet.pyの中にUpSampling モジュールを実装していきます。

├─ data
|   └─ VOCdevkit
|      └─ VOC2012
|           ├─ Annotations
|           ├─ ImageSets
|           ├─ JPEGImages
|           ├─ SegmentationClass
|           └─ SegmentationObject
├─ initmodel
|   └─ resnet50_v2.pth
├─ model
|   ├─ pspnet.py         # ここにUpSampling モジュールを追記していく
|   └─ resnet.py         
└─ util
    └─ dataloader.py

次に今回使用するライブラリです。

import torch
import torch.nn as nn
import torch.nn.functional as F

最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。




1. UpSamplingモジュールの役割

f:id:venoda:20210816045506p:plain

Pyramid Pooling モジュールの出力をさらに畳み込みを行い入力画像と同じ大きさに変換します。

Pyramid Pooling モジュールの出力は (4096×60×60)なので、UpSamplingモジュールでは( 21(クラス数)×475(高さ)×475(幅) )に変換するということです。

次にUpSamplingモジュールの動きについて解説します。

以下の図を使って全体の動きを解説します。

f:id:venoda:20210816045632p:plain

処理の流れとして、Pyramid Pooling モジュールからの出力を入力にします。この時の入力のサイズは(4096×60×60)になります。

これを畳み込みフィルタを通して、(21×60×60)に変換します。

正解データとの誤差を計算するために入力画像と同じ大きさに拡大します。

この一連の処理をすることで、最終的な出力は(クラス数×高さ×幅)になります。

以上が、UpSamplingモジュールの役割と動きの解説になります。




2. PSPNetの実装

PSPNetにUpSamplingモジュールを追加していきます。

やることとしては畳み込みフィルとと画像拡大処理を追加するだけです。

ここではUpSamplingモジュールの追加のみですので、PSPNetの実装はまだ先に続きますのでご注意ください。

import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.resnet import resnet50


class PSPNet(nn.Module):
    
    def __init__(self):
        super(PSPNet, self).__init__()

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        
        # Pyramid Pooling モジュール
        self.ppm = PyramidPoolingModule()

        # UpSampling モジュール
        self.cls = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, 21, kernel_size=1)
        )
            
    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling モジュール
        x = self.ppm(x)

        # UpSampling モジュール
        x = self.cls(x)
        # 入力画像と同じ大きさに変換する
        x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True)

        return x

動作確認として、実際にモデルを動かしてみます。

# ダミーデータの作成
input = torch.rand(4, 3, 475, 475)

model = PSPNet()
model.eval()
output = model(input)

print(output.shape)
# Output
torch.Size([4, 21, 475, 475])

想定通り、torch.size([バッチサイズ, 21, 475, 475])の形で出力されています。

これで、UpSamplingモジュールの実装が完了です。


3. まとめ

今回の記事ではPSPNet内にUpSamplingモジュールを追加しました。

これだけではPSPNetは完成でないなので、さらに実装を続けていきます。

最後に全体のコードを載せておきます。




4. 参考




5. 全体コード

# model.pspnet.py
### ライブラリ
import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.resnet import resnet50


### クラス定義
class PyramidPoolingModule(nn.Module):
    def __init__(self):
        super(PyramidPoolingModule, self).__init__()
        
        # 1×1スケール
        self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU(inplace=True)

        # 2×2スケール
        self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2)
        self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(512)
        self.relu2 = nn.ReLU(inplace=True)
        
        # 3×3スケール
        self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3)
        self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(512)
        self.relu3 = nn.ReLU(inplace=True)
        
        # 4×4スケール
        self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6)
        self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU(inplace=True)
        
    def forward(self, x):
        # 1×1スケール
        out1 = self.avgpool1(x)
        out1 = self.conv1(out1)
        out1 = self.bn1(out1)
        out1 = self.relu1(out1)
        out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True)
        
        # 2×2スケール
        out2 = self.avgpool2(x)
        out2 = self.conv2(out2)
        out2 = self.bn2(out2)
        out2 = self.relu2(out2)
        out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True)
        
        # 3×3スケール
        out3 = self.avgpool3(x)
        out3 = self.conv3(out3)
        out3 = self.bn3(out3)
        out3 = self.relu3(out3)
        out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True)
        
        # 6×6スケール
        out6 = self.avgpool6(x)
        out6 = self.conv6(out6)
        out6 = self.bn6(out6)
        out6 = self.relu6(out6)
        out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True)
        
        # 元の入力と各スケールの特徴量を結合させる
        out = torch.cat([x, out1, out2, out3, out6], dim=1)
        
        return out


class PSPNet(nn.Module):
    
    def __init__(self):
        super(PSPNet, self).__init__()

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        
        # Pyramid Pooling モジュール
        self.ppm = PyramidPoolingModule()

        # UpSampling モジュール
        self.cls = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, 21, kernel_size=1)
        )
            
    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling モジュール
        x = self.ppm(x)

        # UpSampling モジュール
        x = self.cls(x)
        # 入力画像と同じ大きさに変換する
        x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True)

        return x


# 動作確認
if __name__ == '__main__':
    input = torch.rand(4, 3, 475, 475)
    model = PSPNet()
    model.eval()
    output = model(input)
    print(output.shape)