PyTorchで実装するPSPNet モジュール実装 FeatureMapモジュール偏 ⑥

PSPNetを実装するにあたり、この記事ではFeatureMapモジュールを実装していきます。

実装については、次の実装を参考にしています。

github.com

ディレクトリ構成は以下のような構成になっています。

initmodelの中には学習済みのResNetのパラメータを保存します。

modelの中にResNetとPSPNetを実装していきます。

├─ data
|   └─ VOCdevkit
|      └─ VOC2012
|           ├─ Annotations
|           ├─ ImageSets
|           ├─ JPEGImages
|           ├─ SegmentationClass
|           └─ SegmentationObject
├─ initmodel
|   └─ resnet50_v2.pth   # 学習済みモデルはここに保存する
├─ model
|   ├─ pspnet.py         # PSPNetはここに記述していく
|   └─ resnet.py         # ResNetはここに記述していく
└─ util
    └─ dataloader.py

次に今回使用するライブラリです。

import torch
import torch.nn as nn
from torchvision import models

最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。




1. FeatureMapモジュールの役割

f:id:venoda:20210815092631p:plain

このモジュールは、入力画像から最初にCNNによる畳込みを行い特徴量を抽出します。

この畳み込みで使用するモデルは、学習済みのResNetを使用しています。

Feature Map モジュールではサイズが (3 × 475 ×475) の入力画像を、(2048 × 60 × 60) に変換します。


2. ResNetの実装

PSPNetで使用している学習済みモデルはResNetを使用しています。

まずはじめに、PSPNet内で使用するResNetを実装していきます。

いくつか実装上の注意点を説明した後に、実装に移っていきたいと思います。


2.1. 使用する学習済みモデルについて

PyTorchのtorchvison.models.resent50で実装されているResNetと、PSPNet内で使用されているResNetはネットワーク構成に少し差異があります。

PyTorchのtorchvison.models.resent50で実装されているResNetは最初の畳み込みが7×7になっています。

それに対して、PSPNet内で使用されているResNetは最初の畳み込みが3×3になっています。

これは分類精度が3×3にしたほうが優れているらしく、こちらを採用しているらしいです。

実際にそれぞれのResNetの構成を一部抜粋して、違いを確認してみます。

まずは、PyTorchのtorchvison.models.resent50で実装されているResNetのネットワーク構成です。

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  ・・・・・
  ・・・・・
)

次に、PSPNet内で使用されているResNetのネットワーク構成です。

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  ・・・・・
  ・・・・・
)

実際にネットワーク構成に差異があるのがわかると思います。

今回はPSPNetで使用しているResNetを実装していきたいと思います。


2.2. 学習済みモデルのダウンロード

学習済みResNetのパラメータは別途ダウンロードする必要があります。

ダウンロード先は以下のURLです。resnet50_v2.pthをダウンロードしてください。

https://drive.google.com/drive/folders/1Hrz1wOxOZm4nIIS7UMJeL79AQrdvpj6v

ダウンロードしたら、initmodelの下に格納してください。

├─ data
|   └─ VOCdevkit
|      └─ VOC2012
|           ├─ Annotations
|           ├─ ImageSets
|           ├─ JPEGImages
|           ├─ SegmentationClass
|           └─ SegmentationObject
├─ initmodel
|   └─ resnet50_v2.pth   # ← ダウンロードしたパラメータファイルをここに格納する
├─ model
|   ├─ pspnet.py
|   └─ resnet.py
└─ util
    └─ dataloader.py


2.3. 実装

ゼロからResNetを実装していくのは非常に大変なので、基本的にはtorchvison.models.resent50で実装されているResNetを使いまわしつつ、差異部分は新規作成または変更していく方針としています。

nn.Moduleを継承させたResNetクラスを作成していきます。

実装としては次のようになります。新規作成箇所、使いまわし箇所、変更箇所をコメントで記載しています。

import torch
import torch.nn as nn
from torchvision import models

class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        
        # 新規作成
        # PSPNetで使用するResNetの最初の畳み込み層の定義を行う
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, stride=2, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, stride=1, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, stride=1, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # 使いまわし
        # PyTorchで実装済みのResNetを読み込み、必要なところだけを使用する
        model = models.resnet50(pretrained=False)
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.avgpool = model.avgpool
        self.fc = model.fc
        
        # 変更
        # Layer1をPSPNetで使用するResNetの実装に合わせるため、一部、Conv2dのパラメータを変更する
        self.layer1[0].conv1 = nn.Conv2d(in_channels=128, out_channels=64, stride=1, bias=False, kernel_size=1)
        self.layer1[0].downsample[0] = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, stride=1, bias=False)
        
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

モデルのパラメータファイルを読み込むことができれば、問題なくResNetが実装できています。

うまく読み込むことができれば、<All keys matched successfully>が出力されます。

モデルとパラメータファイルに差異があると、例外が発生します。

# ダウンロードしたモデルのパラメータファイル
model_path = '../initmodel/resnet50_v2.pth'

resnet = ResNet()
resnet.load_state_dict(torch.load(model_path), strict=False)
# Output
<All keys matched successfully>

最後に、PSPNetからResNetを呼び出すときに、モデル読み込みまでを行う関数を作成しておきます。

def resnet50(pretrained=False):
    model = ResNet()
    
    if pretrained is True:
        model_path = '../initmodel/resnet50_v2.pth'
        model.load_state_dict(torch.load(model_path), strict=False)
        
    return model

これにてResNetの実装が完了です。




3. PSPNetの実装

実装したResNetを使用して、PSPNetのFeature Map モジュールを実装していきます。

ここではFeature Map モジュールの実装のみですので、PSPNetの実装はまだ先に続きますのでご注意ください。


3.1. 実装上の注意点:畳み込みフィルタのパラメータ変更

まず最初に、畳み込みフィルタにおけるdilationというものを説明します。

dilationとは、畳み込みフィルタのセル間隔を設定するものになります。

kernel_size=3のフィルタの場合、dilationが1と2の場合の動きの違いを図に示します。

f:id:venoda:20210815092711p:plain

図から見てわかるように、dilationの効果として、より広範囲のセルの影響を考慮することができます。

通常の畳み込みフィルタではdilationが1になっています。PSPNetでは大局的な特徴を抽出するために、dilationの値を特別に設定します。

そのため、ResNetのlayer3とlayer4の一部の層のパラメータを変更する処理を行っています。


3.2. 実装上の注意点:途中の出力を抜き出す

forward処理のところで、layer3からの出力を別途抜き出しています。

これは後で解説するAuxLossモジュールのためです。

def forward(self, x):
    x = self.layer0(x)
    x = self.layer1(x)
    x = self.layer2(x)
    # AuxLossのためにlayer3から出力を抜き出しておく
    x_tmp = self.layer3(x)
    x = self.layer4(x_tmp)

    return x


3.3. 実装

nn.Moduleを継承させたPSPNetクラスを作成していきます。

上で説明した注意点をもとに実装します。

PSPNetではResNetを全部使用するわけではなく、最初の畳み込みフィルタとlayer1からlayer4までを使用します。

import sys
sys.path.append('../')

import torch
import torch.nn as nn
from model.resnet import resnet50


class PSPNet(nn.Module):
    
    def __init__(self):
        super(PSPNet, self).__init__()

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
                
            
    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        
        return x

動作確認として、実際にモデルを動かしみます。

# ダミーデータの作成
input = torch.rand(4, 3, 475, 475)

model = PSPNet()
model.eval()
output = model(input)

print(output.shape)
# Output
torch.Size([4, 2048, 60, 60])

想定通り、torch.size([バッチサイズ, 2048, 60, 60])の形で出力されています。

これでFeature Map モジュールの実装が完了です。




4. まとめ

今回の記事ではResNetの実装と、PSPNet内のFeature Map モジュールを実装しました。

これだけではPSPNetは完成でないなので、さらに実装を続けていきます。

最後に全体のコードを載せておきます。


5. 参考文献




6. 全体コード

# model.resnet.py
### ライブラリ
import torch
import torch.nn as nn

from torchvision import models


### クラス定義
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()

        # 新規作成
        # PSPNetで使用するResNetの最初の畳み込み層の定義を行う
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, stride=2, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, stride=1, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, stride=1, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 使いまわし
        # PyTorchで実装済みのResNetを読み込み、必要なところだけを使用する
        model = models.resnet50(pretrained=False)
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.avgpool = model.avgpool
        self.fc = model.fc

        # 変更
        # Layer1をPSPNetで使用するResNetの実装に合わせるため、一部、Conv2dのパラメータを変更する
        self.layer1[0].conv1 = nn.Conv2d(in_channels=128, out_channels=64, stride=1, bias=False, kernel_size=1)
        self.layer1[0].downsample[0] = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


### 関数定義
def resnet50(pretrained=False):
    model = ResNet()
    
    if pretrained is True:
        model_path = '../initmodel/resnet50_v2.pth'
        model.load_state_dict(torch.load(model_path), strict=False)
        
    return model


# 動作確認
if __name__ == '__main__':
    input = torch.rand(4, 3, 475, 475)
    model = resnet50(pretrained=True)
    model.eval()
    output = model(input)
    print(output.shape)
# model.pspnet.py
### ライブラリ
import sys
sys.path.append('../')

import torch
import torch.nn as nn

from model.resnet import resnet50


### クラス定義
class PSPNet(nn.Module):
    
    def __init__(self):
        super(PSPNet, self).__init__()

        # ResNetから最初の畳み込みフィルタとlayer1からlayer4を取得する
        resnet = resnet50(pretrained=True)
        self.layer0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, 
            resnet.conv2, resnet.bn2, resnet.relu, 
            resnet.conv3, resnet.bn3, resnet.relu, 
            resnet.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # layer3とlayer4の畳み込みフィルタのパラメータを変更する
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)

            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
                
            
    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # AuxLossのためにlayer3から出力を抜き出しておく
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        
        return x


# 動作確認
if __name__ == '__main__':
    input = torch.rand(4, 3, 475, 475)
    model = PSPNet()
    model.eval()
    output = model(input)
    print(output.shape)