PyTorch Lightningによる多クラス分類の実装

1. 概要

PyTorchはモデル構築の際に柔軟にコードを書くことができますが、自由度が高いがゆえに学習用のループ処理などが複雑になりがちです。

PyTorch Lightningを使えば、PyTorchで書いていた学習用のループ処理などを分離・自動化できるため取り回しが格段に良くなります。

今回の記事ではPyTorch Lightningを使って多クラス分類を実装していきたいと思います。

チュートリアルを見るといろいろな設定ができますが、すでにあるPyTorchの全体感を損ねないように学習時のループ処理をPyTorch Lightningに置き換える形で実装しました。

PyTorch Lightningを使用しないで多クラス分類を実装した記事もあるので、良ければ参考にしながら見ていただければと思います。

venoda.hatenablog.com

2. モデル化の流れ

PyTorchは次の流れでモデル化していけば大きく間違えることはないかと思います。

  1. 準備
    1. データ準備
    2. 前処理
    3. Datasetの作成
    4. DataLoaderの作成
  2. Lightningモジュールの定義
    1. ネットワークの定義
    2. 損失関数の定義
    3. 最適化手法の定義
  3. 学習・評価
    1. 学習処理の設定
    2. 学習と予測の実行

では流れに沿って実装していきたいと思います。

今回使用するライブラリです。

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy



3. 準備

3.1. データ準備、前処理

まずはSklearnのirisデータを読み込み、学習データと検証データに分けます。

その後、PyTorchに入力できるようにTensor型にデータを変換します。

# データ読み込み
iris = datasets.load_iris()
data = iris['data']
target = iris['target']

# 学習データと検証データに分割
x_train, x_valid, y_train, y_valid = train_test_split(data, target, shuffle=True)

# 特徴量の標準化
scaler = StandardScaler()
scaler.fit(x_train)

x_train = scaler.transform(x_train)
x_valid = scaler.transform(x_valid)

# Tensor型に変換
# 学習に入れるときはfloat型 or long型になっている必要があるのここで変換してしまう
x_train = torch.from_numpy(x_train).float()
y_train = torch.from_numpy(y_train).long()
x_valid = torch.from_numpy(x_valid).float()
y_valid = torch.from_numpy(y_valid).long()

print('x_train : ', x_train.shape)
print('y_train : ', y_train.shape)
print('x_valid : ', x_valid.shape)
print('y_valid : ', y_valid.shape)
# Output
x_train :  torch.Size([112, 4])
y_train :  torch.Size([112])
x_valid :  torch.Size([38, 4])
y_valid :  torch.Size([38])

3.2. Datasetの作成

PyTorchのTensorDatasetを使って説明変数と目的変数をワンセットにしたDatasetを作成します。

train_dataset = TensorDataset(x_train, y_train)
valid_dataset = TensorDataset(x_valid, y_valid)

# 動作確認
# indexを指定すればデータを取り出すことができます。
index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])
# Output
torch.Size([4])
tensor([1, 0, 0], dtype=torch.uint8)

3.3. DataLoaderの作成

バッチ処理を適用するためにDataLoaderを作成して、データセットをバッチ単位で取り出せるようにします。

batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# 動作確認
# こんな感じでバッチ単位で取り出す子ができます。
# イテレータに変換
batch_iterator = iter(train_dataloader)
# 1番目の要素を取り出す
inputs, labels = next(batch_iterator)
print(inputs.size())
print(labels.size())
# Output
torch.Size([32, 4])
torch.Size([32])



4. Lightningモジュールの定義

Lightningモジュールを定義するためには、pl.LightningModuleを継承させたクラスを作成します。

このLightningモジュールの中にネットワーク、損失関数、最適化関数などを定義していきます。

class Net(pl.LightningModule):
    # この中にネットワークの定義を記述していく
    # ~

4.1. ネットワークの定義

「__init__」の中に親クラスのコンストラクタの呼び出しと使用する層を記述して、「forward」の中に順伝播処理を記述していきます。

class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 50)
        self.fc2 = nn.Linear(50, 3)
    
    # 順伝播処理を記述する
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        
        return x

net = Net()
print(net)
# Output
Net(
  (fc1): Linear(in_features=4, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=3, bias=True)
)

4.2. 損失関数の定義

Lightningモジュールの中のtraining_step関数に損失関数を定義していきます。

これは各バッチ内での処理に相当します。

各epochの終了時に精度を確認したいので、training_epoch_end関数も同時に定義しています。

損失関数はクロスエントロピー誤差を使用します。

class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        # ~~
    
    # 順伝播処理を記述する
    def forward(self, x):
        # ~~
       
    # 学習の際に使用する損失関数を記述している
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        return {'loss': loss, 'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    # 各エポック終了時の処理を記述している
    def training_epoch_end(self, train_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)

        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def validation_epoch_end(self, val_step_outputs):
        # x_hatを一つにまとめる
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)

        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))

関数の内容を少し詳しく解説します。

元のPyTorchのコードに照らしあわせると、それぞれ関数は次の位置に当たります。

training_step関数は各バッチ単位に相当し、training_epoch_end関数は各エポック単位に相当します。

model = Model()
model.train()
torch.set_grad_enabled(True)

for epoch in epochs:
    outputs = []
    for batch in data:
        x, y = batch                                            # traninig_step
        y_hat = model(x)                                        # traninig_step
        loss = loss(y_hat, x)                                   # traninig_step
        outputs.append({'loss': loss, 'y_hat', y_hat, 'y': y})  # traninig_step
        
        # 学習処理が続く...

    total_loss = outputs.mean()                                 # training_epoch_end

training_step関数では損失関数を定義し、lossを戻り値に指定しています。

この際にloss単体を戻り値に指定するか、lossというKeyを持つ辞書を返す必要があるようです。

つまりは、↓のように書くか、

def training_step(self, batch, batch_idx):
    # ~~
    return loss

または、↓のような感じにする必要があります。

def training_step(self, batch, batch_idx):
    # ~~
    return {'loss': loss, 'x_hat': x_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}

各エポック終了時に、精度を測るためにtraining_step関数の戻り値に正解ラベルと予測結果を格納しています。

training_epoch_end関数では、train_step_outputsに各バッチでの結果がリスト形式に格納されているので、ループを回して精度を計算します。

lossの計算ときにPyTorchの仕様上各バッチ内での平均のlossが計算されます。

loss.item() * x.size()をすることで、lossにデータ数を掛けてlossの平均から合計に変換しています。

training_epoch_end関数内で「全データの損失 ÷ データ数」とすることで、モデル全体の損失和を計算しなおす処理を加えています。

validation_step関数validation_epoch_end関数は、評価時に使用する関数で基本的にはtraining_step関数training_epoch_end関数とやっていることは同じです。

4.3. 最適化手法の定義

Lightningモジュールの中のconfigure_optimizers関数に最適化手法を定義していきます。

Lightningモジュール自体にパラメータを持っているので、self.parameters()で最適関数に渡します。

class Net(nn.Module):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        # ~~
    
    # 順伝播処理を記述する
    def forward(self, x):
        # ~~
       
    # 学習の際に使用する損失関数を記述している
    def training_step(self, batch, batch_idx):
        # ~~
    
    # 各エポック終了時の処理を記述している
    def training_epoch_end(self, train_step_outputs):
        # ~~
        
    def validation_step(self, batch, batch_idx):
        # ~~

    def validation_epoch_end(self, val_step_outputs):
        # ~~
    
    # 最適化手法を記述する
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.01)

        return optimizer



5. 学習・評価

ここからがPyTorchであれば、あれこれ学習に関連する処理を書くのですが、PyTorch Lightningであれば簡単にできてしまいます。

5.1. 学習処理の設定

Trainerに学習に必要なパラメータを指定します。

これだけで学習の設定が可能になります。

# ネットワーク作成
net = Net()

# EarlyStoppingの設定
es = pl.callbacks.EarlyStopping(monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[es],
    # GPUを使用する場合
    # gpus=2
)

5.2 学習と予測の実行

学習の実行はtrainerに必要な変数を渡すだけです。

trainer.fit(
    net,   # ネットワーク
    train_dataloader=train_dataloader,   # 学習データ
    val_dataloaders=valid_dataloader,    # 検証データ
)
# Output
-------- Current Epoch 1 --------
train Loss: 1.1252 train Acc: 0.2679
valid Loss: 1.1136 valid Acc: 0.3158
-------- Current Epoch 2 --------
train Loss: 1.1209 train Acc: 0.2768
valid Loss: 1.1100 valid Acc: 0.3158
-------- Current Epoch 3 --------
train Loss: 1.1166 train Acc: 0.2768
valid Loss: 1.1064 valid Acc: 0.3158

~~~~~~

-------- Current Epoch 28 --------
train Loss: 0.9939 train Acc: 0.8929
valid Loss: 1.0050 valid Acc: 0.8684
-------- Current Epoch 29 --------
train Loss: 0.9893 train Acc: 0.8929
valid Loss: 1.0012 valid Acc: 0.8684
-------- Current Epoch 30 --------
train Loss: 0.9848 train Acc: 0.8929
valid Loss: 0.9973 valid Acc: 0.8684

Lightningモジュールは機能は追加されてはいますが、基本的には同じPyTorchのtorch.nn.Moduleなので、同じように操作することができます。

つまり、予測時はPyTorchと変わらない操作で予測することができます。

# 予測用のダミーデータ
x = torch.randn(10, 4)

# 予測の実行
preds = net(x)

以上でPyTorchでのモデル化の流れが完了です。




6. 全体のコード

最後に全体のコードをのせておきます。

# ライブラリ
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

# データ準備、前処理
# データ読み込み
iris = datasets.load_iris()
data = iris['data']
target = iris['target']

# 学習データと検証データに分割
x_train, x_valid, y_train, y_valid = train_test_split(data, target, shuffle=True)

# 特徴量の標準化
scaler = StandardScaler()
scaler.fit(x_train)

x_train = scaler.transform(x_train)
x_valid = scaler.transform(x_valid)

# Tensor型に変換
# 学習に入れるときはfloat型になっている必要があるのここで変換してしまう
x_train = torch.from_numpy(x_train).float()
y_train = torch.from_numpy(y_train).long()
x_valid = torch.from_numpy(x_valid).float()
y_valid = torch.from_numpy(y_valid).long()

print('x_train : ', x_train.shape)
print('y_train : ', y_train.shape)
print('x_valid : ', x_valid.shape)
print('y_valid : ', y_valid.shape)

# Datasetの作成
train_dataset = TensorDataset(x_train, y_train)
valid_dataset = TensorDataset(x_valid, y_valid)

# 動作確認
# indexを指定すればデータを取り出すことができます。
index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])

# DataLoaderの作成
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# 動作確認
# こんな感じでバッチ単位で取り出す子ができます。
# イテレータに変換
batch_iterator = iter(train_dataloader)
# 1番目の要素を取り出す
inputs, labels = next(batch_iterator)
print(inputs.size())
print(labels.size())

# Lightningモジュール
class Net(pl.LightningModule):
    
    # ネットワークで使用する層を記述する
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 50)
        self.fc2 = nn.Linear(50, 3)
    
    # 順伝播処理を記述する
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        
        return x

    # 学習の際に使用する損失関数を記述している
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        return {'loss': loss, 'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    # 各エポック終了時の処理を記述している
    def training_epoch_end(self, train_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)

        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def validation_epoch_end(self, val_step_outputs):
        # x_hatを一つにまとめる
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)

        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)

        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
        
    # 最適化手法を記述する
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.01)

        return optimizer

# Lightningモジュールの作成
net = Net()

# 学習処理の設定
# EarlyStoppingの設定
es = pl.callbacks.EarlyStopping(monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[es],
    # GPUを使用する場合
    # gpus=2
)

# 学習の実行
trainer.fit(
    net,   # ネットワーク
    train_dataloader=train_dataloader,   # 学習データ
    val_dataloaders=valid_dataloader,    # 検証データ
)