実際にPSPNetを実装していく前に、PSPNetの全体像と各モジュールの役割を解説します。
それぞれのモジュールがどのような機能を持つかを学びます。
1. 全体像
2. FeatureMap モジュール
このモジュールは、入力画像から最初にCNNによる畳込みを行い特徴量を抽出します。
この畳み込みで使用するモデルは、学習済みのResNetを使用しています。
3. Pyramid Pooling モジュール
Feature Map モジュールで抽出した特徴量を、Poolingを行い様々なスケールに変換します。
論文中では1×1、2×2、3×3、6×6のスケールの層を作成しています。
異なるスケールの特徴量を作成することで、全体を考慮するまたは部分的に考慮するなどといったPSPNetの強みであるピクセルの周辺情報を考慮した特徴量を作成することができる。
それぞれのスケールに変換したら、それをさらに畳み込みます。
4. UpSampling モジュール
Pyramid Pooling モジュールで小さくなった特徴量の出力を拡大します。
拡大したそれぞれのスケールの特徴量と元の特徴量を結合させます。
結合させたら、入力画像のサイズと同じ大きさに変換します。
5. AuxLoss モジュール
Feature Map モジュール、Pyramid Pooling モジュール、UpSampling モジュールの3つのモジュールでセグメンテーションモデルを実現することはできますが、AuxLoss モジュールを追加することで、損失関数の計算補助を行います。
Feature Map モジュールの途中からとってきた出力を、UpSampling モジュールと同様に入力画像のサイズと同じ大きさに変換します。
AuxLoss モジュールとUpSampling モジュールの二つの出力を、アノテーションデータと対応させて損失値を計算し、バックプロパゲーションを行います。
ここで注意点として、AuxLoss モジュールは学習時のみ使用します。推論時はAuxLoss モジュールは使用しないで予測をお行います。
6. 参考文献
論文
-
この論文の内容を紹介しています
-
-
https://github.com/hszhao/semseg
実装方法については、こちらのGithubのコードを参考にしています。
書籍
つくりながら学ぶ!PyTorchによる発展ディープラーニング
ファインチューニングの方法はこちらの書籍を参考にしています。