PyTorch Lightningによる画像分類の実装

1. 概要

PyTorchはモデル構築の際に柔軟にコードを書くことができますが、自由度が高いがゆえに学習用のループ処理などが複雑になりがちです。

PyTorch Lightningを使えば、PyTorchで書いていた学習用のループ処理などを分離・自動化できるため取り回しが格段に良くなります。

今回の記事ではPyTorch Lightningを使って画像分類を実装していきたいと思います。

学習済みモデルを使わずに、自分で定義したモデルを使用して実装しています。

チュートリアルを見るといろいろな設定ができますが、すでにあるPyTorchの全体感を損ねないように学習時のループ処理をPyTorch Lightningに置き換える形で実装しました。

PyTorch Lightningを使用しないで画像分類を実装した記事もあるので、良ければ参考にしながら見ていただければと思います。

venoda.hatenablog.com

2. モデル化の流れ

PyTorch Lightningは次の流れでモデル化していけば大きく間違えることはないかと思います。

  1. 準備
    1. データ準備
    2. 前処理
    3. Datasetの作成
    4. DataLoaderの作成
  2. Lightningモジュールの定義
    1. ネットワークの定義
    2. 損失関数の定義
    3. 最適化手法の定義
  3. 学習・評価
    1. 学習処理の設定
    2. 学習と予測の実行

では流れに沿って実装していきたいと思います。



3. 準備

3.1 データ準備

今回使用するデータは、犬のデータセットを使用して画像分類を試みます。

まずは下記のサイトにアクセスして画像データをダウンロードします。

http://vision.stanford.edu/aditya86/ImageNetDogs/

f:id:venoda:20201018014456p:plain

「Images」のリンクをクリックすれば、画像データのダウンロードが開始されます。

ダウンロードしたデータを解凍すると、120種類の犬種がディレクトリ別に格納されています。

すべてのデータを使うと計算リソースを必要とするので、次の5犬種に絞ってモデルを構築していきます。

・チワワ(Chihuahua)

シーズー(Shih-Tzu)

ボルゾイ(borzoi)

・パグ(pug)

グレートデン(Great-Dane)

ディレクトリ構成は下記のようにしてあります。

├── Images  # 画像データ
│   ├── n02085620-Chihuahua
│   ├── n02086240-Shih-Tzu
│   ├── n02090622-borzoi
│   ├── n02109047-Great_Dane
│   └── n02110958-pug
└── PyTorch Lightningによる画像分類の実装.ipynb  # 実装スクリプト

最後に今回使用するライブラリです。

# ライブラリの読み込み
import os
from PIL import Image

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

import matplotlib.pyplot as plt
%matplotlib inline

3.2 前処理

Datasetにデータを渡す準備として、前処理をしていきます。

まずは、用意したデータセットを学習データと検証データに分割します。

格納されているデータ数は各犬種ごとに異なるので、それぞれ学習データを80%、検証データを20%にします。

そして、学習データ、検証データそれぞれのファイルパスを格納したリストを用意します。

def make_filepath_list():
    """
    学習データ、検証データそれぞれのファイルへのパスを格納したリストを返す
    
    Returns
    -------
    train_file_list: list
        学習データファイルへのパスを格納したリスト
    valid_file_list: list
        検証データファイルへのパスを格納したリスト
    """
    train_file_list = []
    valid_file_list = []

    for top_dir in os.listdir('./Images/'):
        file_dir = os.path.join('./Images/', top_dir)
        file_list = os.listdir(file_dir)

        # 各犬種ごとに8割を学習データ、2割を検証データとする
        num_data = len(file_list)
        num_split = int(num_data * 0.8)

        train_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[:num_split]]
        valid_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[num_split:]]
    
    return train_file_list, valid_file_list

# 画像データへのファイルパスを格納したリストを取得する
train_file_list, valid_file_list = make_filepath_list()

print('学習データ数 : ', len(train_file_list))
# 先頭3件だけ表示
print(train_file_list[:3])

print('検証データ数 : ', len(valid_file_list))
# 先頭3件だけ表示
print(valid_file_list[:3])
# Output
学習データ数 :  696
['./Images/n02085620-Chihuahua/n02085620_10074.jpg', './Images/n02085620-Chihuahua/n02085620_10131.jpg', './Images/n02085620-Chihuahua/n02085620_10621.jpg']
検証データ数 :  177
['./Images/n02085620-Chihuahua/n02085620_588.jpg', './Images/n02085620-Chihuahua/n02085620_5927.jpg', './Images/n02085620-Chihuahua/n02085620_6295.jpg']

次に画像データに対する前処理(resizeなど)の処理を記述したクラスを作成します。

このクラスに画像データを通せば、指定した前処理を施したデータがえられます。

{train: ~, valid: ~}としている理由は、学習時と推論時で実施する前処理を変えるためです。

学習時にはモデルの性能を高めるために、データオーグメンテーション用の前処理を加えていますが、推論時にはデータオーグメンテーションする必要がないため、画像のサイズを整えるなどの処理だけにとどめています。

class ImageTransform(object):
    """
    入力画像の前処理クラス
    画像のサイズをリサイズする
    
    Attributes
    ----------
    resize: int
        リサイズ先の画像の大きさ
    mean: (R, G, B)
        各色チャンネルの平均値
    std: (R, G, B)
        各色チャンネルの標準偏差
    """
    def __init__(self, resize, mean, std):
        self.data_trasnform = {
            'train': transforms.Compose([
                # データオーグメンテーション
                transforms.RandomHorizontalFlip(),
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ]),
            'valid': transforms.Compose([
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ])
        }
    
    def __call__(self, img, phase='train'):
        return self.data_trasnform[phase](img)

# 動作確認
img = Image.open('./Images/n02085620-Chihuahua/n02085620_199.jpg')

# リサイズ先の画像サイズ
resize = 300

# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

transform = ImageTransform(resize, mean, std)
img_transformed = transform(img, 'train')

plt.imshow(img)
plt.show()

plt.imshow(img_transformed.numpy().transpose((1, 2, 0)))
plt.show()

実際にImageTransformクラスを実行してみると、画像がリサイズされ色も標準化されていることがわかります。

加えて、phaseパラメータをtrainに設定して何回か実行すると、画像が左右反転したデータも出力されることがわかります。

3.3 Datasetの作成

PyTorchのDatasetクラスを継承させたクラスを作ります。

処理を記述する箇所は「__len__」「__getitem__」の二つです。

「__len__」にはDatasetに含まれるデータ数を返す処理を記述します。

「__getitem__」にはindex番号を引数にとり、学習データとラベルデータを返す処理を記述します。

class DogDataset(data.Dataset):
    """
    犬種のDataseクラス。
    PyTorchのDatasetクラスを継承させる。
    
    Attrbutes
    ---------
    file_list: list
        画像のファイルパスを格納したリスト
    classes: list
        犬種のラベル名
    transform: object
        前処理クラスのインスタンス
    phase: 'train' or 'valid'
        学習か検証化を設定
    """
    def __init__(self, file_list, classes, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.classes = classes
        self.phase = phase
    
    def __len__(self):
        """
        画像の枚数を返す
        """
        return len(self.file_list)
    
    def __getitem__(self, index):
        """
        前処理した画像データのTensor形式のデータとラベルを取得
        """
        # 指定したindexの画像を読み込む
        img_path = self.file_list[index]
        img = Image.open(img_path)
        
        # 画像の前処理を実施
        img_transformed = self.transform(img, self.phase)
        
        # 画像ラベルをファイル名から抜き出す
        label = self.file_list[index].split('/')[2][10:]
        
        # ラベル名を数値に変換
        label = self.classes.index(label)
        
        return img_transformed, label


# 動作確認
# クラス名
dog_classes = [
    'Chihuahua',  'Shih-Tzu',
    'borzoi', 'Great_Dane', 'pug'
]

# リサイズ先の画像サイズ
resize = 300

# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# Datasetの作成
train_dataset = DogDataset(
    file_list=train_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='train'
)

valid_dataset = DogDataset(
    file_list=valid_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='valid'
)

index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])
# Output
torch.Size([3, 300, 300])
0

3.4 DataLoaderの作成

バッチ処理を適用するためにDataLoaderを作成して、データセットをバッチ単位で取り出せるようにします。

# バッチサイズの指定
batch_size = 64

# DataLoaderを作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

valid_dataloader = data.DataLoader(
    valid_dataset, batch_size=32, shuffle=False)

# 動作確認
# イテレータに変換
batch_iterator = iter(train_dataloader)

# 1番目の要素を取り出す
inputs, labels = next(batch_iterator)

print(inputs.size())
print(labels)
# Output
torch.Size([32, 3, 300, 300])
tensor([2, 0, 2, 4, 4, 3, 4, 4, 3, 3, 1, 1, 3, 4, 4, 0, 1, 4, 0, 4, 3, 1, 1, 0,
        4, 4, 3, 0, 1, 1, 1, 3])



4. Lighgningモジュールの定義

Lightningモジュールを定義するためには、pl.LightningModuleを継承させたクラスを作成します。

このLightningモジュールの中にネットワーク、損失関数、最適化関数などを定義していきます。

class Net(pl.LightningModule):
    # この中にネットワークの定義を記述していく
    # ~

4.1. ネットワークの定義

「__init__」の中に親クラスのコンストラクタの呼び出しと使用する層を記述して、「forward」の中に順伝播処理を記述していきます。

class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(in_features=128 * 75 * 75, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=5)
    
    # 順伝播処理を記述する
    def forward(self, x):
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.pool1(x)
        
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.pool2(x)

        x = x.view(-1, 128 * 75 * 75)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        
        return x

net = Net()
print(net)
# Output
Net(
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=720000, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=5, bias=True)
)

4.2 損失関数の定義

Lightningモジュールの中のtraining_step関数に損失関数を定義していきます。

これは各バッチ内での処理に相当します。

各epochの終了時に精度を確認したいので、training_epoch_end関数も同時に定義しています。

損失関数はクロスエントロピー誤差を使用します。

class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        # ~~
    
    # 順伝播処理を記述する
    def forward(self, x):
        # ~~
       
    # 学習の際に使用する損失関数を記述している
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        return {'loss': loss, 'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    # 各エポック終了時の処理を記述している
    def training_epoch_end(self, train_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)

        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def validation_epoch_end(self, val_step_outputs):
        # x_hatを一つにまとめる
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)

        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))

関数の内容を少し詳しく解説します。

元のPyTorchのコードに照らしあわせると、それぞれ関数は次の位置に当たります。

training_step関数は各バッチ単位に相当し、training_epoch_end関数は各エポック単位に相当します。

model = Model()
model.train()
torch.set_grad_enabled(True)

for epoch in epochs:
    outputs = []
    for batch in data:
        x, y = batch                                            # traninig_step
        y_hat = model(x)                                        # traninig_step
        loss = loss(y_hat, x)                                   # traninig_step
        outputs.append({'loss': loss, 'y_hat', y_hat, 'y': y})  # traninig_step
        
        # 学習処理が続く...

    total_loss = outputs.mean()                                 # training_epoch_end

training_step関数では損失関数を定義し、lossを戻り値に指定しています。

この際にloss単体を戻り値に指定するか、lossというKeyを持つ辞書を返す必要があるようです。

つまりは、↓のように書くか、

def training_step(self, batch, batch_idx):
    # ~~
    return loss

または、↓のような感じにする必要があります。

def training_step(self, batch, batch_idx):
    # ~~
    return {'loss': loss, 'x_hat': x_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}

各エポック終了時に、精度を測るためにtraining_step関数の戻り値に正解ラベルと予測結果を格納しています。

training_epoch_end関数では、train_step_outputsに各バッチでの結果がリスト形式に格納されているので、ループを回して精度を計算します。

lossの計算ときにPyTorchの仕様上各バッチ内での平均のlossが計算されます。

loss.item() * x.size()をすることで、lossにデータ数を掛けてlossの平均から合計に変換しています。

training_epoch_end関数内で「全データの損失 ÷ データ数」とすることで、モデル全体の損失和を計算しなおす処理を加えています。

validation_step関数validation_epoch_end関数は、評価時に使用する関数で基本的にはtraining_step関数training_epoch_end関数とやっていることは同じです。

4.3. 最適化手法の定義

Lightningモジュールの中のconfigure_optimizers関数に最適化手法を定義していきます。

Lightningモジュール自体にパラメータを持っているので、self.parameters()で最適関数に渡します。

class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        # ~~
    
    # 順伝播処理を記述する
    def forward(self, x):
        # ~~
       
    # 学習の際に使用する損失関数を記述している
    def training_step(self, batch, batch_idx):
        # ~~
    
    # 各エポック終了時の処理を記述している
    def training_epoch_end(self, train_step_outputs):
        # ~~
        
    def validation_step(self, batch, batch_idx):
        # ~~

    def validation_epoch_end(self, val_step_outputs):
        # ~~
    
    # 最適化手法を記述する
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.01)

        return optimizer



5. 学習・評価

ここからがPyTorchであれば、あれこれ学習に関連する処理を書くのですが、PyTorch Lightningであれば簡単にできてしまいます。

5.1. 学習処理の設定

Trainerに学習に必要なパラメータを指定します。

これだけで学習の設定が可能になります。

# ネットワーク作成
net = Net()

# EarlyStoppingの設定
es = pl.callbacks.EarlyStopping(monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[es],
    # GPUを使用する場合
    # gpus=2
)

5.2 学習と予測の実行

学習の実行はtrainerに必要な変数を渡すだけです。

trainer.fit(
    net,   # ネットワーク
    train_dataloader=train_dataloader,   # 学習データ
    val_dataloaders=valid_dataloader,    # 検証データ
)
# Output
-------- Current Epoch 1 --------
train Loss: 1.6096 train Acc: 0.2184
valid Loss: 1.6090 valid Acc: 0.2429
-------- Current Epoch 2 --------
train Loss: 1.6087 train Acc: 0.2457
valid Loss: 1.6083 valid Acc: 0.2429
-------- Current Epoch 3 --------
train Loss: 1.6079 train Acc: 0.2457
valid Loss: 1.6076 valid Acc: 0.2429

~~~~~~

-------- Current Epoch 28 --------
train Loss: 1.5691 train Acc: 0.3132
valid Loss: 1.5798 valid Acc: 0.2994
-------- Current Epoch 29 --------
train Loss: 1.5656 train Acc: 0.3147
valid Loss: 1.5777 valid Acc: 0.2881
-------- Current Epoch 30 --------
train Loss: 1.5613 train Acc: 0.3247
valid Loss: 1.5760 valid Acc: 0.2881

Lightningモジュールは機能は追加されてはいますが、基本的には同じPyTorchのtorch.nn.Moduleなので、同じように操作することができます。

つまり、予測時はPyTorchと変わらない操作で予測することができます。

# 予測用のダミーデータ
x = torch.randn(3, 3, 300, 300)

# 予測の実行
preds = net(x)

以上でPyTorchでのモデル化の流れが完了です。

今回はパラメータチューニング、転移学習、ファインチューニングなどは実施せずに画像分類モデルを作成しました。

予測精度の良いモデルは得られませんでした。




6. 全体のコード

最後に全体のコードをまとめて載せておきます。

# ライブラリの読み込み
import os
from PIL import Image

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

import matplotlib.pyplot as plt
%matplotlib inline

# 学習データ、検証データへの分割
def make_filepath_list():
    """
    学習データ、検証データそれぞれのファイルへのパスを格納したリストを返す
    
    Returns
    -------
    train_file_list: list
        学習データファイルへのパスを格納したリスト
    valid_file_list: list
        検証データファイルへのパスを格納したリスト
    """
    train_file_list = []
    valid_file_list = []

    for top_dir in os.listdir('./Images/'):
        file_dir = os.path.join('./Images/', top_dir)
        file_list = os.listdir(file_dir)

        # 各犬種ごとに8割を学習データ、2割を検証データとする
        num_data = len(file_list)
        num_split = int(num_data * 0.8)

        train_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[:num_split]]
        valid_file_list += [os.path.join('./Images', top_dir, file).replace('\\', '/') for file in file_list[num_split:]]
    
    return train_file_list, valid_file_list

# 前処理クラス
class ImageTransform(object):
    """
    入力画像の前処理クラス
    画像のサイズをリサイズする
    
    Attributes
    ----------
    resize: int
        リサイズ先の画像の大きさ
    mean: (R, G, B)
        各色チャンネルの平均値
    std: (R, G, B)
        各色チャンネルの標準偏差
    """
    def __init__(self, resize, mean, std):
        self.data_trasnform = {
            'train': transforms.Compose([
                # データオーグメンテーション
                transforms.RandomHorizontalFlip(),
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ]),
            'valid': transforms.Compose([
                # 画像をresize×resizeの大きさに統一する
                transforms.Resize((resize, resize)),
                # Tensor型に変換する
                transforms.ToTensor(),
                # 色情報の標準化をする
                transforms.Normalize(mean, std)
            ])
        }
    
    def __call__(self, img, phase='train'):
        return self.data_trasnform[phase](img)
    
# Datasetクラス
class DogDataset(data.Dataset):
    """
    犬種のDataseクラス。
    PyTorchのDatasetクラスを継承させる。
    
    Attrbutes
    ---------
    file_list: list
        画像のファイルパスを格納したリスト
    classes: list
        犬種のラベル名
    transform: object
        前処理クラスのインスタンス
    phase: 'train' or 'valid'
        学習か検証化を設定
    """
    def __init__(self, file_list, classes, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.classes = classes
        self.phase = phase
    
    def __len__(self):
        """
        画像の枚数を返す
        """
        return len(self.file_list)
    
    def __getitem__(self, index):
        """
        前処理した画像データのTensor形式のデータとラベルを取得
        """
        # 指定したindexの画像を読み込む
        img_path = self.file_list[index]
        img = Image.open(img_path)
        
        # 画像の前処理を実施
        img_transformed = self.transform(img, self.phase)
        
        # 画像ラベルをファイル名から抜き出す
        label = self.file_list[index].split('/')[2][10:]
        
        # ラベル名を数値に変換
        label = self.classes.index(label)
        
        return img_transformed, label

# Lightningモジュール
class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(in_features=128 * 75 * 75, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=5)
    
    # 順伝播処理を記述する
    def forward(self, x):
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.pool1(x)
        
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.pool2(x)

        x = x.view(-1, 128 * 75 * 75)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        
        return x

    # 学習の際に使用する損失関数を記述している
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        return {'loss': loss, 'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    # 各エポック終了時の処理を記述している
    def training_epoch_end(self, train_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)

        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def validation_epoch_end(self, val_step_outputs):
        # x_hatを一つにまとめる
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)

        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))

    # 最適化手法を記述する
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.01)

        return optimizer
    
# 各種パラメータの用意
# クラス名
dog_classes = [
    'Chihuahua',  'Shih-Tzu',
    'borzoi', 'Great_Dane', 'pug'
]

# リサイズ先の画像サイズ
resize = 300

# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# バッチサイズの指定
batch_size = 64

# エポック数
num_epochs = 30

# 前処理
# 学習データ、検証データのファイルパスを格納したリストを取得する
train_file_list, valid_file_list = make_filepath_list()

# Datasetの作成
train_dataset = DogDataset(
    file_list=train_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='train'
)

valid_dataset = DogDataset(
    file_list=valid_file_list, classes=dog_classes,
    transform=ImageTransform(resize, mean, std),
    phase='valid'
)

# DataLoaderの作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

valid_dataloader = data.DataLoader(
    valid_dataset, batch_size=32, shuffle=False)


# Lightningモジュールの作成
net = Net()

# 学習処理の設定
# EarlyStoppingの設定
es = pl.callbacks.EarlyStopping(monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[es],
    # GPUを使用する場合
    # gpus=2
)

# 学習の実行
trainer.fit(
    net,   # ネットワーク
    train_dataloader=train_dataloader,   # 学習データ
    val_dataloaders=valid_dataloader,    # 検証データ
)