PyTorchによるファインチューニングの実装

1. 概要

PyTorchを使ってファインチューニングによる画像分類を実装していきたいと思います。

今回はVGG16を使ってモデルを実装していきます。

2. モデル化の流れ

PyTorchは次の流れでモデル化していけば大きく間違えることはないかと思います。

  1. 準備
    1. データ準備
    2. 前処理
    3. Datasetの作成
    4. DataLoaderの作成
  2. ネットワークの作成
    1. ネットワークの定義
    2. 損失関数の定義
    3. 最適化手法の定義
  3. 学習と予測
    1. 学習の実行
    2. 予測の実行

では流れに沿って実装していきたいと思います。



3. 準備

3.1. データ準備

今回使用するデータは、犬のデータセットを使用して画像分類を試みます。

まずは下記のサイトにアクセスして画像データをダウンロードします。

http://vision.stanford.edu/aditya86/ImageNetDogs/

f:id:venoda:20201018014456p:plain

「Images」のリンクをクリックすれば、画像データのダウンロードが開始されます。

ダウンロードしたデータを解凍すると、120種類の犬種がディレクトリ別に格納されています。

すべてのデータを使うと計算リソースを必要とするので、次の5犬種に絞ってモデルを構築していきます。

・チワワ(Chihuahua)

シーズー(Shih-Tzu)

ボルゾイ(borzoi)

・パグ(pug)

グレートデン(Great-Dane)

ディレクトリ構成は下記のようにしてあります。

├── Images  # 画像データ
│   ├── n02085620-Chihuahua
│   ├── n02086240-Shih-Tzu
│   ├── n02090622-borzoi
│   ├── n02109047-Great_Dane
│   └── n02110958-pug
└── PyTorchによる転移学習の実装.ipynb  # 実装スクリプト

最後に今回使用するライブラリです。

# ライブラリの読み込み
import os
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

3.2. 前処理

Datasetにデータを渡す準備として、前処理をしていきます。

まずは、用意したデータセットを学習データと検証データに分割します。

格納されているデータ数は各犬種ごとに異なるので、それぞれ学習データを80%、検証データを20%にします。

そして、学習データ、検証データそれぞれのファイルパスを格納したリストを用意します。

def make_filepath_list():
    """
    学習データ、検証データそれぞれのファイルへのパスを格納したリストを返す
    
    Returns
    -------
    train_file_list: list
        学習データファイルへのパスを格納したリスト
    valid_file_list: list
        検証データファイルへのパスを格納したリスト
    """
    train_file_list = []
    valid_file_list = []

    for top_dir in os.listdir('./Images/'):
        file_dir = os.path.join('./Images/', top_dir)
        file_list = os.listdir(file_dir)

        # 各犬種ごとに8割を学習データ、2割を検証データとする
        num_data = len(file_list)
        num_split = int(num_data * 0.8)

        train_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[:num_split]]
        valid_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[num_split:]]
    
    return train_file_list, valid_file_list

# 画像データへのファイルパスを格納したリストを取得する
train_file_list, valid_file_list = make_filepath_list()

print('学習データ数 : ', len(train_file_list))
# 先頭3件だけ表示
print(train_file_list[:3])

print('検証データ数 : ', len(valid_file_list))
# 先頭3件だけ表示
print(valid_file_list[:3])
# Output
学習データ数 :  696
['./Images/n02085620-Chihuahua/n02085620_10074.jpg', './Images/n02085620-Chihuahua/n02085620_10131.jpg', './Images/n02085620-Chihuahua/n02085620_10621.jpg']
検証データ数 :  177
['./Images/n02085620-Chihuahua/n02085620_588.jpg', './Images/n02085620-Chihuahua/n02085620_5927.jpg', './Images/n02085620-Chihuahua/n02085620_6295.jpg']

次に画像データに対する前処理(resizeなど)の処理を記述したクラスを作成します。

このクラスに画像データを通せば、指定した前処理を施したデータがえられます。

{train: ~, valid: ~}としている理由は、学習時と推論時で実施する前処理を変えるためです。

学習時にはモデルの性能を高めるために、データオーグメンテーション用の前処理を加えていますが、推論時にはデータオーグメンテーションする必要がないため、画像のサイズを整えるなどの処理だけにとどめています。

class ImageTransform(object):
    """
    入力画像の前処理クラス
    画像のサイズをリサイズする
    
    Attributes
    ----------
    resize: int
        リサイズ先の画像の大きさ
    mean: (R, G, B)
        各色チャンネルの平均値
    std: (R, G, B)
        各色チャンネルの標準偏差
    """
    def __init__(self, resize, mean, std):
        self.data_trasnform = {
            'train': transforms.Compose([
                # データオーグメンテーション
                transforms.RandomHorizontalFlip(),
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ]),
            'valid': transforms.Compose([
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ])
        }
    
    def __call__(self, img, phase='train'):
        return self.data_trasnform[phase](img)

# 動作確認
img = Image.open('./Images/n02085620-Chihuahua/n02085620_199.jpg')

# リサイズ先の画像サイズ
resize = 300

# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

transform = ImageTransform(resize, mean, std)
img_transformed = transform(img, 'train')

plt.imshow(img)
plt.show()

plt.imshow(img_transformed.numpy().transpose((1, 2, 0)))
plt.show()

実際にImageTransformクラスを実行してみると、画像がリサイズされ色も標準化されていることがわかります。

加えて、phaseパラメータをtrainに設定して何回か実行すると、画像が左右反転したデータも出力されることがわかります。

3.3. Datasetの作成

PyTorchのDatasetクラスを継承させたクラスを作ります。

処理を記述する箇所は「__len__」「__getitem__」の二つです。

「__len__」にはDatasetに含まれるデータ数を返す処理を記述します。

「__getitem__」にはindex番号を引数にとり、学習データとラベルデータを返す処理を記述します。

class DogDataset(data.Dataset):
    """
    犬種のDataseクラス。
    PyTorchのDatasetクラスを継承させる。
    
    Attrbutes
    ---------
    file_list: list
        画像のファイルパスを格納したリスト
    classes: list
        犬種のラベル名
    transform: object
        前処理クラスのインスタンス
    phase: 'train' or 'valid'
        学習か検証化を設定
    """
    def __init__(self, file_list, classes, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.classes = classes
        self.phase = phase
    
    def __len__(self):
        """
        画像の枚数を返す
        """
        return len(self.file_list)
    
    def __getitem__(self, index):
        """
        前処理した画像データのTensor形式のデータとラベルを取得
        """
        # 指定したindexの画像を読み込む
        img_path = self.file_list[index]
        img = Image.open(img_path)
        
        # 画像の前処理を実施
        img_transformed = self.transform(img, self.phase)
        
        # 画像ラベルをファイル名から抜き出す
        label = self.file_list[index].split('/')[2][10:]
        
        # ラベル名を数値に変換
        label = self.classes.index(label)
        
        return img_transformed, label


# 動作確認
# クラス名
dog_classes = [
    'Chihuahua',  'Shih-Tzu',
    'borzoi', 'Great_Dane', 'pug'
]

# リサイズ先の画像サイズ
resize = 300

# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# Datasetの作成
train_dataset = DogDataset(
    file_list=train_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='train'
)

valid_dataset = DogDataset(
    file_list=valid_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='valid'
)

index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])
# Output
torch.Size([3, 300, 300])
0

3.4. DataLoaderの作成

バッチ処理を適用するためにDataLoaderを作成して、データセットをバッチ単位で取り出せるようにします。

のちの学習のために、学習用のDataLoaderと検証用のDataLoaderを辞書にまとめています。

# バッチサイズの指定
batch_size = 64

# DataLoaderを作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

valid_dataloader = data.DataLoader(
    valid_dataset, batch_size=32, shuffle=False)

# 辞書にまとめる
dataloaders_dict = {
    'train': train_dataloader, 
    'valid': valid_dataloader
}

# 動作確認
# イテレータに変換
batch_iterator = iter(dataloaders_dict['train'])

# 1番目の要素を取り出す
inputs, labels = next(batch_iterator)

print(inputs.size())
print(labels)
# Output
torch.Size([32, 3, 300, 300])
tensor([2, 0, 2, 4, 4, 3, 4, 4, 3, 3, 1, 1, 3, 4, 4, 0, 1, 4, 0, 4, 3, 1, 1, 0,
        4, 4, 3, 0, 1, 1, 1, 3])



4. ネットワークの作成

4.1. ネットワークの定義

ネットワークを定義するために、VGG16モデルをロードします。

PyTorchを使えば簡単にロードすることができます。

初回のロードの場合、ダウンロードが発生するので少し時間がかかるかもしれません。

# 学習済みの重みを使用
use_pretrained = True

# モデルをロード
net = models.vgg16(pretrained=use_pretrained)

print(net)
# Output
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

次に、モデルを今回のタスクに適応させるために最後の出力層を書き換えます。

モデルは内部的に、(features)、(avgpool)、(classifier)と分かれているので、リストを操作する感じで操作できます。

書き換える際は (classifier) の (6) の Linear(in_features=4096, out_features=1000, bias=True)を

Linear(in_features=4096, out_features=5, bias=True) に書き換えます。

print('変更前 : ', net.classifier[6])

net.classifier[6] = nn.Linear(in_features=4096, out_features=5)

print('変更前 : ', net.classifier[6])
# Output
変更前 :  Linear(in_features=4096, out_features=1000, bias=True)
変更前 :  Linear(in_features=4096, out_features=5, bias=True)

4.2. 損失関数の定義

損失関数を定義します。

今回はクロスエントロピー誤差を使用します。

criterion = nn.CrossEntropyLoss()

4.3. 最適化手法の定義

今回はファインチューニングを実装するので、更新するモデルの重みを指定していきます。

ファインチューニングの場合、全層の重みを更新できるように設定します。

この時に入力に近い層の重みをなるべく更新しないようにしたいので、各層ごとに学習率を変えれるように設定していきます。

今回の場合はfeatureモジュールはあまり重みを更新させたくないので、入力に近い層は学習率を小さく設定します。

まずは重みがどのように格納されているか確認します。

list(net.classifier[6].named_parameters())
# Output
[('weight', Parameter containing:
  tensor([[ 1.1528e-02, -2.1545e-03,  1.3566e-03,  ..., -5.5306e-03,
            6.0829e-03, -5.3546e-03],
          [-2.1816e-03, -3.8899e-04,  9.2145e-03,  ..., -5.5057e-03,
           -5.2748e-05, -9.0574e-03],
          [ 5.4369e-03, -1.3873e-02,  1.2081e-03,  ...,  1.3849e-02,
           -1.0094e-02,  2.4163e-03],
          [ 8.5649e-03, -6.7938e-03,  4.0852e-03,  ..., -3.0715e-03,
            1.0122e-02, -7.3107e-03],
          [-1.2985e-02,  2.3871e-03,  1.6530e-03,  ..., -7.8724e-04,
           -2.1222e-03, -5.0700e-03]], requires_grad=True)),
 ('bias', Parameter containing:
  tensor([-0.0048,  0.0019,  0.0097, -0.0110,  0.0085], requires_grad=True))]

出力から確認できるようにweight、biasが格納されています。

net.named_parameters()と実行すると、各層の名前と重みが得られます。

これを活用していきます。

# 出力内容の確認
for name, param in net.named_parameters():
    print('name : ', name)
# Output
name :  features.0.weight
name :  features.0.bias
name :  features.2.weight
name :  features.2.bias

~~~~~~

name :  classifier.3.weight
name :  classifier.3.bias
name :  classifier.6.weight
name :  classifier.6.bias

各層に設定を適用するために、各層の重みを別のリストに格納していきます。

# featureモジュール
params_to_update_1 = []
# classifierモジュール(後半)
params_to_update_2 = []
# classifierモジュール(付け替えた層)
params_to_update_3 = []

# 学習させる層のパラメータ名を指定
update_param_names_1 = ['features']
update_param_names_2 = ['classifier.0.weight', 'classifier.0.bias',
                        'classifier.3.weight', 'classifier.3.bias']
update_param_names_3 = ['classifier.6.weight', 'classifier.6.bias']

# パラメータごとに各リストに格納
for name, param in net.named_parameters():

    if update_param_names_1[0] in name:
        param.requires_grad = True
        params_to_update_1.append(param)
        print("params_to_update_1に格納:", name)
    
    elif name in update_param_names_2:
        param.requires_grad = True
        params_to_update_2.append(param)
        print("params_to_update_2に格納:", name)
    
    elif name in update_param_names_3:
        param.requires_grad = True
        params_to_update_3.append(param)
        print("params_to_update_3に格納:", name)
# Output
params_to_update_1に格納: features.0.weight
params_to_update_1に格納: features.0.bias
params_to_update_1に格納: features.2.weight
params_to_update_1に格納: features.2.bias
params_to_update_1に格納: features.5.weight
params_to_update_1に格納: features.5.bias
params_to_update_1に格納: features.7.weight
params_to_update_1に格納: features.7.bias
params_to_update_1に格納: features.10.weight
params_to_update_1に格納: features.10.bias
params_to_update_1に格納: features.12.weight
params_to_update_1に格納: features.12.bias
params_to_update_1に格納: features.14.weight
params_to_update_1に格納: features.14.bias
params_to_update_1に格納: features.17.weight
params_to_update_1に格納: features.17.bias
params_to_update_1に格納: features.19.weight
params_to_update_1に格納: features.19.bias
params_to_update_1に格納: features.21.weight
params_to_update_1に格納: features.21.bias
params_to_update_1に格納: features.24.weight
params_to_update_1に格納: features.24.bias
params_to_update_1に格納: features.26.weight
params_to_update_1に格納: features.26.bias
params_to_update_1に格納: features.28.weight
params_to_update_1に格納: features.28.bias
params_to_update_2に格納: classifier.0.weight
params_to_update_2に格納: classifier.0.bias
params_to_update_2に格納: classifier.3.weight
params_to_update_2に格納: classifier.3.bias
params_to_update_3に格納: classifier.6.weight
params_to_update_3に格納: classifier.6.bias

最後に最適化手法を定義します。

ゼロからモデルのすべての層の重みを学習する場合、net.parameters()で一括で指定しますが、ファインチューニングでは、上の処理で振り分けた層に対してそれぞれの学習率を設定します。

次のような感じで設定することができます。

optimizer = optim.SGD([
    {'params': params_to_update_1, 'lr': 1e-4},
    {'params': params_to_update_2, 'lr': 5e-4},
    {'params': params_to_update_3, 'lr': 1e-3},
], momentum=0.9)



5. 学習と予測

5.1. 学習の実行

PyTorchでは学習時推論時でネットワークのモードを分ける必要があります。

「net.train()」「net.eval()」でそれぞれのモードを分ける処理を書いています。

# エポック数
num_epochs = 30

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-------------')
    
    for phase in ['train', 'valid']:
        if phase == 'train':
            # 学習モードに設定
            net.train()
        else:
            # 訓練モードに設定
            net.eval()
            
        # epochの損失和
        epoch_loss = 0.0
        # epochの正解数
        epoch_corrects = 0
        
        for inputs, labels in dataloaders_dict[phase]:

            # optimizerを初期化
            optimizer.zero_grad()
            
            # 学習時のみ勾配を計算させる設定にする
            with torch.set_grad_enabled(phase == 'train'):
                
                outputs = net(inputs)
                
                # 損失を計算
                loss = criterion(outputs, labels)
                
                # ラベルを予測
                _, preds = torch.max(outputs, 1)
                
                # 訓練時は逆伝搬の計算
                if phase == 'train':
                    # 逆伝搬の計算
                    loss.backward()
                    
                    # パラメータ更新
                    optimizer.step()
                    
                # イテレーション結果の計算
                # lossの合計を更新
                # PyTorchの仕様上各バッチ内での平均のlossが計算される。
                # データ数を掛けることで平均から合計に変換をしている。
                # 損失和は「全データの損失/データ数」で計算されるため、
                # 平均のままだと損失和を求めることができないため。
                epoch_loss += loss.item() * inputs.size(0)
                
                # 正解数の合計を更新
                epoch_corrects += torch.sum(preds == labels.data)

        # epochごとのlossと正解率を表示
        epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
        epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
# Output
Epoch 1/30
-------------
train Loss: 1.2237 Acc: 0.5503
valid Loss: 0.4901 Acc: 0.9492
Epoch 2/30
-------------
train Loss: 0.2875 Acc: 0.9641
valid Loss: 0.1568 Acc: 0.9887
Epoch 3/30
-------------
train Loss: 0.1230 Acc: 0.9698
valid Loss: 0.0964 Acc: 0.9887
Epoch 4/30
-------------
train Loss: 0.0848 Acc: 0.9828
valid Loss: 0.0766 Acc: 0.9887
Epoch 5/30
-------------
train Loss: 0.0603 Acc: 0.9871
valid Loss: 0.0649 Acc: 0.9887
Epoch 6/30
-------------
train Loss: 0.0553 Acc: 0.9842
valid Loss: 0.0577 Acc: 0.9887
Epoch 7/30
-------------
train Loss: 0.0443 Acc: 0.9957
valid Loss: 0.0524 Acc: 0.9887
Epoch 8/30
-------------
train Loss: 0.0406 Acc: 0.9928
valid Loss: 0.0489 Acc: 0.9887
Epoch 9/30
-------------
train Loss: 0.0364 Acc: 0.9928
valid Loss: 0.0460 Acc: 0.9887
Epoch 10/30
-------------
train Loss: 0.0282 Acc: 0.9971
valid Loss: 0.0445 Acc: 0.9887

~~~~

以上でPyTorchでのモデル化の流れが完了です。

5.2. 予測の実行

あとは、学習したモデルで推論してやれば予測結果を取得できます。

# 予測用のダミーデータ
x = torch.randn([1, 3, 300, 300])
preds = net(x)



6. 所感

今回はファインチューニングを使用して画像分類モデルを実装しました。

かなり高い精度で予測ができているようです。

今回は30回程学習させましたが、10回目のループくらいから若干ですが過学習しているように見受けられます。

ループを回す回数は調整が必要かと思われます。




7. 全体のコード

最後に全体のコードをまとめて載せておきます。

# ライブラリの読み込み
import os
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

# 学習データ、検証データへの分割
def make_filepath_list():
    """
    学習データ、検証データそれぞれのファイルへのパスを格納したリストを返す
    
    Returns
    -------
    train_file_list: list
        学習データファイルへのパスを格納したリスト
    valid_file_list: list
        検証データファイルへのパスを格納したリスト
    """
    train_file_list = []
    valid_file_list = []

    for top_dir in os.listdir('./Images/'):
        file_dir = os.path.join('./Images/', top_dir)
        file_list = os.listdir(file_dir)

        # 各犬種ごとに8割を学習データ、2割を検証データとする
        num_data = len(file_list)
        num_split = int(num_data * 0.8)

        train_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[:num_split]]
        valid_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[num_split:]]
    
    return train_file_list, valid_file_list

# 前処理クラス
class ImageTransform(object):
    """
    入力画像の前処理クラス
    画像のサイズをリサイズする
    
    Attributes
    ----------
    resize: int
        リサイズ先の画像の大きさ
    mean: (R, G, B)
        各色チャンネルの平均値
    std: (R, G, B)
        各色チャンネルの標準偏差
    """
    def __init__(self, resize, mean, std):
        self.data_trasnform = {
            'train': transforms.Compose([
                # データオーグメンテーション
                transforms.RandomHorizontalFlip(),
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ]),
            'valid': transforms.Compose([
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ])
        }
    
    def __call__(self, img, phase='train'):
        return self.data_trasnform[phase](img)
    
# Datasetクラス
class DogDataset(data.Dataset):
    """
    犬種のDataseクラス。
    PyTorchのDatasetクラスを継承させる。
    
    Attrbutes
    ---------
    file_list: list
        画像のファイルパスを格納したリスト
    classes: list
        犬種のラベル名
    transform: object
        前処理クラスのインスタンス
    phase: 'train' or 'valid'
        学習か検証化を設定
    """
    def __init__(self, file_list, classes, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.classes = classes
        self.phase = phase
    
    def __len__(self):
        """
        画像の枚数を返す
        """
        return len(self.file_list)
    
    def __getitem__(self, index):
        """
        前処理した画像データのTensor形式のデータとラベルを取得
        """
        # 指定したindexの画像を読み込む
        img_path = self.file_list[index]
        img = Image.open(img_path)
        
        # 画像の前処理を実施
        img_transformed = self.transform(img, self.phase)
        
        # 画像ラベルをファイル名から抜き出す
        label = self.file_list[index].split('/')[2][10:]
        
        # ラベル名を数値に変換
        label = self.classes.index(label)
        
        return img_transformed, label

# パラメータ
# クラス名
dog_classes = [
    'Chihuahua',  'Shih-Tzu',
    'borzoi', 'Great_Dane', 'pug'
]

# リサイズ先の画像サイズ
resize = 300

# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# バッチサイズの指定
batch_size = 64

# エポック数
num_epochs = 30

# 2. 前処理
# 学習データ、検証データのファイルパスを格納したリストを取得する
train_file_list, valid_file_list = make_filepath_list()

# 3. Datasetの作成
train_dataset = DogDataset(
    file_list=train_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='train'
)

valid_dataset = DogDataset(
    file_list=valid_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='valid'
)

# 4. DataLoaderの作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

valid_dataloader = data.DataLoader(
    valid_dataset, batch_size=32, shuffle=False)

# 辞書にまとめる
dataloaders_dict = {
    'train': train_dataloader, 
    'valid': valid_dataloader
}

# 5. ネットワークの定義
use_pretrained = True
net = models.vgg16(pretrained=use_pretrained)

net.classifier[6] = nn.Linear(in_features=4096, out_features=5)

# 6. 損失関数の定義
criterion = nn.CrossEntropyLoss()

# 7. 最適化手法の定義
# 各層の重みをリストに格納していく
# featureモジュール
params_to_update_1 = []
# classifierモジュール(後半)
params_to_update_2 = []
# classifierモジュール(付け替えた層)
params_to_update_3 = []

# 学習させる層のパラメータ名を指定
update_param_names_1 = ['features']
update_param_names_2 = ['classifier.0.weight', 'classifier.0.bias',
                        'classifier.3.weight', 'classifier.3.bias']
update_param_names_3 = ['classifier.6.weight', 'classifier.6.bias']

# パラメータごとに各リストに格納
for name, param in net.named_parameters():

    if update_param_names_1[0] in name:
        param.requires_grad = True
        params_to_update_1.append(param)
        print("params_to_update_1に格納:", name)
    
    elif name in update_param_names_2:
        param.requires_grad = True
        params_to_update_2.append(param)
        print("params_to_update_2に格納:", name)
    
    elif name in update_param_names_3:
        param.requires_grad = True
        params_to_update_3.append(param)
        print("params_to_update_3に格納:", name)

optimizer = optim.SGD([
    {'params': params_to_update_1, 'lr': 1e-4},
    {'params': params_to_update_2, 'lr': 5e-4},
    {'params': params_to_update_3, 'lr': 1e-3},
], momentum=0.9)

# 8. 学習・検証
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-------------')
    
    for phase in ['train', 'valid']:
        if phase == 'train':
            # 学習モードに設定
            net.train()
        else:
            # 訓練モードに設定
            net.eval()
            
        # epochの損失和
        epoch_loss = 0.0
        # epochの正解数
        epoch_corrects = 0
        
        for inputs, labels in dataloaders_dict[phase]:

            # optimizerを初期化
            optimizer.zero_grad()
            
            # 学習時のみ勾配を計算させる設定にする
            with torch.set_grad_enabled(phase == 'train'):
                
                outputs = net(inputs)
                
                # 損失を計算
                loss = criterion(outputs, labels)
                
                # ラベルを予測
                _, preds = torch.max(outputs, 1)
                
                # 訓練時は逆伝搬の計算
                if phase == 'train':
                    # 逆伝搬の計算
                    loss.backward()
                    
                    # パラメータ更新
                    optimizer.step()
                    
                # イテレーション結果の計算
                # lossの合計を更新
                # PyTorchの仕様上各バッチ内での平均のlossが計算される。
                # データ数を掛けることで平均から合計に変換をしている。
                # 損失和は「全データの損失/データ数」で計算されるため、
                # 平均のままだと損失和を求めることができないため。
                epoch_loss += loss.item() * inputs.size(0)
                
                # 正解数の合計を更新
                epoch_corrects += torch.sum(preds == labels.data)

        # epochごとのlossと正解率を表示
        epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
        epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))