PyTorchで実装するPSPNet モジュール実装 AuxLoss モジュール偏 ⑨

PSPNetを実装するにあたり、この記事ではAuxLossモジュールを実装していきます。

実装については、次の実装を参考にしています。

github.com

ディレクトリ構成は以下のような構成になっています。

modelpspnet.pyの中にAuxLossモジュールを実装していきます。

├─ data
|   └─ VOCdevkit
|      └─ VOC2012
|           ├─ Annotations
|           ├─ ImageSets
|           ├─ JPEGImages
|           ├─ SegmentationClass
|           └─ SegmentationObject
├─ initmodel
|   └─ resnet50_v2.pth
├─ model
|   ├─ pspnet.py         # ここにAuxLossモジュールを追記していく
|   └─ resnet.py         
└─ util
    └─ dataloader.py

次に今回使用するライブラリです。

import torch
import torch.nn as nn
import torch.nn.functional as F

最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。




1. AuxLossモジュールの役割

f:id:venoda:20210818213129p:plain

Feature Map モジュール、Pyramid Pooling モジュール、UpSampling モジュールの3つのモジュールでセグメンテーションモデルを実現することはできますが、AuxLoss モジュールを追加することで、損失関数の計算補助を行います。

Feature Map モジュールの途中からとってきた出力を、UpSampling モジュールと同様に入力画像のサイズと同じ大きさに変換します。

AuxLoss モジュールとUpSampling モジュールの二つの出力を、アノテーションデータと対応させて損失値を計算し、バックプロパゲーションを行うことでAuxLossモジュールが計算の補助を行うというわけです。


次にAuxLossモジュールの動きについて解説します。

以下の図を使って全体の動きを解説します。

f:id:venoda:20210818213143p:plain

まずはメインとなるのは、FatureMapモジュール、Pyramid Poolingモジュール、UpSamplingモジュールを通るルートになります。

これがメインとなるルートで最終的な出力とアノテーションデータとの損失(Loss 1)を計算します。


AuxLossモジュールは、FeatureMapモジュールのLayer3を途中で抜き出し、分岐させる形でAuxLossモジュールに渡します。

FeatureMapモジュールの途中の出力を渡したら畳み込みを行い入力画像と同じ大きさに画像を拡大させます。

このAuxLossモジュールからの出力とアノテーションデータとの損失(Loss 2)を計算します。

当然、FeatureMapモジュールの途中までの出力を使用するので分類精度としては低くなります。

ディープラーニングではネットワークが深くなればなるほど学習が困難になる傾向があり、ResNetもネットワークが深いので学習が難しいモデルになります。

そのためAuxLossモジュールではFeatureMapモジュールの途中の出力を抜き出して損失を計算することによって、ResNetの前半の層(Layer3より前)の学習を補助する役割を担います。


学習時にはメインルートの損失(Loss 1)で全体を最適化しつつ、AuxLossモジュールの損失(Loss 2)でResNetの前半の層(Layer3より前)を補助しながら学習を進めます。

以上が、AuxLossモジュールの役割と動きの解説になります。




2. PSPNetの実装

2.1. 実装上の注意点:学習時と推論時の動作を制御する

AuxLossモジュールはあくまでも学習時に補助的に使用するため、学習時には使用しますが推論時には使用しません。

つまり、PSPNetでは学習時または推論時で動作を変える必要があります。

やり方としては簡単で、nn.Moduleを継承させたクラスを作成すると、変数としてtrainingを持っています。

これは学習時であればTrueを返し、推論時であればFalseを返します。

この変数を使用して、学習時にはAuxLossモジュールを使用し、推論時にはAuxLossモジュールを使用しないように動作を制御します。

こんな風に書くだけで制御できます。

# 学習時かどうかはクラスの変数として持っているので、self.trainingを参照するだけでOK
if self.training is True:
    self.aux = nn.Sequential(
        nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.1),
        nn.Conv2d(256, 21, kernel_size=21)
    )

ちなみに、PyTorchで作成したモデルを学習か推論かを指定するには次のように設定します。

model = PSPNet()

# 学習用に設定
model.train()

# 推論用に設定
model.eval()


2.2. 実装

PSPNetにAuxLossモジュールを追加していきます。

やることとしてはLayer3から出力を抜き出してAuxLossモジュールに渡しやることと、学習時のみAuxLossモジュールを使用するように制御することです。

加えて、モデル内のパラメータはわかりやすさを重視して今までべた書きしていましたが、今後の学習を見据えて分類するクラス数を指定できるように変更しておきます。

import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.resnet import resnet50


class PSPNet(nn.Module):
    
    def __init__(self, n_classes=21):
        super(PSPNet, self).__init__()
        self.n_classes = n_classes

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        
        # Pyramid Pooling モジュール
        self.ppm = PyramidPoolingModule()

        # UpSampling モジュール
        self.cls = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, self.n_classes, kernel_size=1)
        )

        # 学習時にのみAuxLossモジュールを使用するように設定
        if self.training is True:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Conv2d(256, self.n_classes, kernel_size=1)
            )
            
    def forward(self, x, y=None):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling モジュール
        x = self.ppm(x)

        # UpSampling モジュール
        x = self.cls(x)
        # 入力画像と同じ大きさに変換する
        x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True)

        # 学習時にのみAuxLossモジュールを使用するように設定
        if self.training is True:
            aux = self.aux(x_tmp)
            aux = F.interpolate(aux, size=(475, 475), mode='bilinear', align_corners=True)
            return x, aux

        return x

動作確認として、実際にモデルを動かしてみます。

# ダミーデータの作成
input = torch.rand(4, 3, 475, 475)

model = PSPNet()

# モデルを学習用に設定した場合
model.train()
output, aux = model(input)
print('学習時')
print(output.shape)
print(aux.shape)

# モデルを推論用に設定した場合
model.eval()
output = model(input)
print('推論時')
print(output.shape)
# Output
学習時
torch.Size([4, 21, 475, 475])
torch.Size([4, 21, 475, 475])
推論時
torch.Size([4, 21, 475, 475])

想定通り、学習時には二つの出力が返りtorch.size([バッチサイズ, 21, 475, 475])の形で出力されています。

一方で推論時には一つの出力が返りtorch.size([バッチサイズ, 21, 475, 475])の形で出力されています。

これで、AuxLossモジュールの実装が完了です。




3. まとめ

今回の記事ではPSPNet内にAuxLossモジュールを追加しました。

これでPSPNetは完成になります。次の記事からは実装したモデルを学習させていきます。

最後に全体のコードを載せておきます。


4. 参考文献




5. 全体コード

# model.pspnet.py
### ライブラリ
import sys

from torch.nn.modules import padding
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.resnet import resnet50


### クラス定義
class PyramidPoolingModule(nn.Module):
    def __init__(self):
        super(PyramidPoolingModule, self).__init__()
        
        # 1×1スケール
        self.avgpool1 = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU(inplace=True)

        # 2×2スケール
        self.avgpool2 = nn.AdaptiveAvgPool2d(output_size=2)
        self.conv2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(512)
        self.relu2 = nn.ReLU(inplace=True)
        
        # 3×3スケール
        self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=3)
        self.conv3 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(512)
        self.relu3 = nn.ReLU(inplace=True)
        
        # 4×4スケール
        self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=6)
        self.conv6 = nn.Conv2d(2048, 512, kernel_size=1, bias=False)
        self.bn6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU(inplace=True)
        
    def forward(self, x):
        # 1×1スケール
        out1 = self.avgpool1(x)
        out1 = self.conv1(out1)
        out1 = self.bn1(out1)
        out1 = self.relu1(out1)
        out1 = F.interpolate(out1, (60, 60), mode='bilinear', align_corners=True)
        
        # 2×2スケール
        out2 = self.avgpool2(x)
        out2 = self.conv2(out2)
        out2 = self.bn2(out2)
        out2 = self.relu2(out2)
        out2 = F.interpolate(out2, (60, 60), mode='bilinear', align_corners=True)
        
        # 3×3スケール
        out3 = self.avgpool3(x)
        out3 = self.conv3(out3)
        out3 = self.bn3(out3)
        out3 = self.relu3(out3)
        out3 = F.interpolate(out3, (60, 60), mode='bilinear', align_corners=True)
        
        # 6×6スケール
        out6 = self.avgpool6(x)
        out6 = self.conv6(out6)
        out6 = self.bn6(out6)
        out6 = self.relu6(out6)
        out6 = F.interpolate(out6, (60, 60), mode='bilinear', align_corners=True)
        
        # 元の入力と各スケールの特徴量を結合させる
        out = torch.cat([x, out1, out2, out3, out6], dim=1)
        
        return out


class PSPNet(nn.Module):
    
    def __init__(self, n_classes=21):
        super(PSPNet, self).__init__()
        self.n_classes = n_classes

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        
        # Pyramid Pooling モジュール
        self.ppm = PyramidPoolingModule()

        # UpSampling モジュール
        self.cls = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, self.n_classes, kernel_size=1)
        )

        # 学習時にのみAuxLossモジュールを使用するように設定
        if self.training is True:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Conv2d(256, self.n_classes, kernel_size=1)
            )
            
    def forward(self, x, y=None):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling モジュール
        x = self.ppm(x)

        # UpSampling モジュール
        x = self.cls(x)
        # 入力画像と同じ大きさに変換する
        x = F.interpolate(x, size=(475, 475), mode='bilinear', align_corners=True)

        # 学習時にのみAuxLossモジュールを使用するように設定
        if self.training is True:
            aux = self.aux(x_tmp)
            aux = F.interpolate(aux, size=(475, 475), mode='bilinear', align_corners=True)
            return x, aux

        return x


# 動作確認
if __name__ == '__main__':
    input = torch.rand(4, 3, 475, 475)
    model = PSPNet()

    # モデルを学習用に設定した場合
    model.train()
    output, aux = model(input)
    print('学習時')
    print(output.shape)
    print(aux.shape)

    # モデルを推論用に設定した場合
    model.eval()
    output = model(input)
    print('推論時')
    print(output.shape)