1. 概要
PyTorchはモデル構築の際に柔軟にコードを書くことができますが、自由度が高いがゆえに学習用のループ処理などが複雑になりがちです。
PyTorch Lightningを使えば、PyTorchで書いていた学習用のループ処理などを分離・自動化できるため取り回しが格段に良くなります。
今回の記事ではPyTorch Lightningを使って多クラス分類を実装していきたいと思います。
チュートリアルを見るといろいろな設定ができますが、すでにあるPyTorchの全体感を損ねないように学習時のループ処理をPyTorch Lightningに置き換える形で実装しました。
PyTorch Lightningを使用しないで多クラス分類を実装した記事もあるので、良ければ参考にしながら見ていただければと思います。
2. モデル化の流れ
PyTorchは次の流れでモデル化していけば大きく間違えることはないかと思います。
- 準備
- データ準備
- 前処理
- Datasetの作成
- DataLoaderの作成
- Lightningモジュールの定義
- ネットワークの定義
- 損失関数の定義
- 最適化手法の定義
- 学習・評価
- 学習処理の設定
- 学習と予測の実行
では流れに沿って実装していきたいと思います。
今回使用するライブラリです。
import pandas as pd from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import TensorDataset from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy
3. 準備
3.1. データ準備、前処理
まずはSklearnのirisデータを読み込み、学習データと検証データに分けます。
その後、PyTorchに入力できるようにTensor型にデータを変換します。
# データ読み込み iris = datasets.load_iris() data = iris['data'] target = iris['target'] # 学習データと検証データに分割 x_train, x_valid, y_train, y_valid = train_test_split(data, target, shuffle=True) # 特徴量の標準化 scaler = StandardScaler() scaler.fit(x_train) x_train = scaler.transform(x_train) x_valid = scaler.transform(x_valid) # Tensor型に変換 # 学習に入れるときはfloat型 or long型になっている必要があるのここで変換してしまう x_train = torch.from_numpy(x_train).float() y_train = torch.from_numpy(y_train).long() x_valid = torch.from_numpy(x_valid).float() y_valid = torch.from_numpy(y_valid).long() print('x_train : ', x_train.shape) print('y_train : ', y_train.shape) print('x_valid : ', x_valid.shape) print('y_valid : ', y_valid.shape)
# Output x_train : torch.Size([112, 4]) y_train : torch.Size([112]) x_valid : torch.Size([38, 4]) y_valid : torch.Size([38])
3.2. Datasetの作成
PyTorchのTensorDatasetを使って説明変数と目的変数をワンセットにしたDatasetを作成します。
train_dataset = TensorDataset(x_train, y_train) valid_dataset = TensorDataset(x_valid, y_valid) # 動作確認 # indexを指定すればデータを取り出すことができます。 index = 0 print(train_dataset.__getitem__(index)[0].size()) print(train_dataset.__getitem__(index)[1])
# Output torch.Size([4]) tensor([1, 0, 0], dtype=torch.uint8)
3.3. DataLoaderの作成
バッチ処理を適用するためにDataLoaderを作成して、データセットをバッチ単位で取り出せるようにします。
batch_size = 32 train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) # 動作確認 # こんな感じでバッチ単位で取り出す子ができます。 # イテレータに変換 batch_iterator = iter(train_dataloader) # 1番目の要素を取り出す inputs, labels = next(batch_iterator) print(inputs.size()) print(labels.size())
# Output torch.Size([32, 4]) torch.Size([32])
4. Lightningモジュールの定義
Lightningモジュールを定義するためには、pl.LightningModule
を継承させたクラスを作成します。
このLightningモジュールの中にネットワーク、損失関数、最適化関数などを定義していきます。
class Net(pl.LightningModule): # この中にネットワークの定義を記述していく # ~
4.1. ネットワークの定義
「__init__」の中に親クラスのコンストラクタの呼び出しと使用する層を記述して、「forward」の中に順伝播処理を記述していきます。
class Net(pl.LightningModule): # ネットワークで使用する層を記述する def __init__(self): super().__init__() self.fc1 = nn.Linear(4, 50) self.fc2 = nn.Linear(50, 3) # 順伝播処理を記述する def forward(self, x): x = self.fc1(x) x = F.relu(x) x = self.fc2(x) x = F.softmax(x, dim=1) return x net = Net() print(net)
# Output Net( (fc1): Linear(in_features=4, out_features=50, bias=True) (fc2): Linear(in_features=50, out_features=3, bias=True) )
4.2. 損失関数の定義
Lightningモジュールの中のtraining_step関数
に損失関数を定義していきます。
これは各バッチ内での処理に相当します。
各epochの終了時に精度を確認したいので、training_epoch_end関数
も同時に定義しています。
損失関数はクロスエントロピー誤差を使用します。
class Net(pl.LightningModule): # ネットワークで使用する層を記述する def __init__(self): # ~~ # 順伝播処理を記述する def forward(self, x): # ~~ # 学習の際に使用する損失関数を記述している def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return {'loss': loss, 'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)} # 各エポック終了時の処理を記述している def training_epoch_end(self, train_step_outputs): y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0) y = torch.cat([val['y'] for val in train_step_outputs], dim=0) epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0) preds = torch.argmax(y_hat, dim=1) acc = accuracy(preds, y) self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True) self.log('train_acc', acc, prog_bar=True, on_epoch=True) print('-------- Current Epoch {} --------'.format(self.current_epoch + 1)) print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc)) def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)} def validation_epoch_end(self, val_step_outputs): # x_hatを一つにまとめる y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0) y = torch.cat([val['y'] for val in val_step_outputs], dim=0) epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0) preds = torch.argmax(y_hat, dim=1) acc = accuracy(preds, y) self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True) self.log('val_acc', acc, prog_bar=True, on_epoch=True) print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
関数の内容を少し詳しく解説します。
元のPyTorchのコードに照らしあわせると、それぞれ関数は次の位置に当たります。
training_step関数
は各バッチ単位に相当し、training_epoch_end関数
は各エポック単位に相当します。
model = Model() model.train() torch.set_grad_enabled(True) for epoch in epochs: outputs = [] for batch in data: x, y = batch # traninig_step y_hat = model(x) # traninig_step loss = loss(y_hat, x) # traninig_step outputs.append({'loss': loss, 'y_hat', y_hat, 'y': y}) # traninig_step # 学習処理が続く... total_loss = outputs.mean() # training_epoch_end
training_step関数
では損失関数を定義し、lossを戻り値に指定しています。
この際にloss単体を戻り値に指定するか、lossというKeyを持つ辞書を返す必要があるようです。
つまりは、↓のように書くか、
def training_step(self, batch, batch_idx): # ~~ return loss
または、↓のような感じにする必要があります。
def training_step(self, batch, batch_idx): # ~~ return {'loss': loss, 'x_hat': x_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
各エポック終了時に、精度を測るためにtraining_step関数
の戻り値に正解ラベルと予測結果を格納しています。
training_epoch_end関数
では、train_step_outputsに各バッチでの結果がリスト形式に格納されているので、ループを回して精度を計算します。
lossの計算ときにPyTorchの仕様上各バッチ内での平均のlossが計算されます。
loss.item() * x.size()
をすることで、lossにデータ数を掛けてlossの平均から合計に変換しています。
training_epoch_end関数
内で「全データの損失 ÷ データ数」とすることで、モデル全体の損失和を計算しなおす処理を加えています。
validation_step関数
とvalidation_epoch_end関数
は、評価時に使用する関数で基本的にはtraining_step関数
とtraining_epoch_end関数
とやっていることは同じです。
4.3. 最適化手法の定義
Lightningモジュールの中のconfigure_optimizers関数
に最適化手法を定義していきます。
Lightningモジュール自体にパラメータを持っているので、self.parameters()
で最適関数に渡します。
class Net(nn.Module): # ネットワークで使用する層を記述する def __init__(self): # ~~ # 順伝播処理を記述する def forward(self, x): # ~~ # 学習の際に使用する損失関数を記述している def training_step(self, batch, batch_idx): # ~~ # 各エポック終了時の処理を記述している def training_epoch_end(self, train_step_outputs): # ~~ def validation_step(self, batch, batch_idx): # ~~ def validation_epoch_end(self, val_step_outputs): # ~~ # 最適化手法を記述する def configure_optimizers(self): optimizer = optim.SGD(self.parameters(), lr=0.01) return optimizer
5. 学習・評価
ここからがPyTorchであれば、あれこれ学習に関連する処理を書くのですが、PyTorch Lightningであれば簡単にできてしまいます。
5.1. 学習処理の設定
Trainer
に学習に必要なパラメータを指定します。
これだけで学習の設定が可能になります。
# ネットワーク作成 net = Net() # EarlyStoppingの設定 es = pl.callbacks.EarlyStopping(monitor='val_loss') trainer = pl.Trainer( max_epochs=30, callbacks=[es], # GPUを使用する場合 # gpus=2 )
5.2 学習と予測の実行
学習の実行はtrainer
に必要な変数を渡すだけです。
trainer.fit( net, # ネットワーク train_dataloader=train_dataloader, # 学習データ val_dataloaders=valid_dataloader, # 検証データ )
# Output -------- Current Epoch 1 -------- train Loss: 1.1252 train Acc: 0.2679 valid Loss: 1.1136 valid Acc: 0.3158 -------- Current Epoch 2 -------- train Loss: 1.1209 train Acc: 0.2768 valid Loss: 1.1100 valid Acc: 0.3158 -------- Current Epoch 3 -------- train Loss: 1.1166 train Acc: 0.2768 valid Loss: 1.1064 valid Acc: 0.3158 ~~~~~~ -------- Current Epoch 28 -------- train Loss: 0.9939 train Acc: 0.8929 valid Loss: 1.0050 valid Acc: 0.8684 -------- Current Epoch 29 -------- train Loss: 0.9893 train Acc: 0.8929 valid Loss: 1.0012 valid Acc: 0.8684 -------- Current Epoch 30 -------- train Loss: 0.9848 train Acc: 0.8929 valid Loss: 0.9973 valid Acc: 0.8684
Lightningモジュールは機能は追加されてはいますが、基本的には同じPyTorchのtorch.nn.Moduleなので、同じように操作することができます。
つまり、予測時はPyTorchと変わらない操作で予測することができます。
# 予測用のダミーデータ x = torch.randn(10, 4) # 予測の実行 preds = net(x)
以上でPyTorchでのモデル化の流れが完了です。
6. 全体のコード
最後に全体のコードをのせておきます。
# ライブラリ import pandas as pd from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import TensorDataset from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy # データ準備、前処理 # データ読み込み iris = datasets.load_iris() data = iris['data'] target = iris['target'] # 学習データと検証データに分割 x_train, x_valid, y_train, y_valid = train_test_split(data, target, shuffle=True) # 特徴量の標準化 scaler = StandardScaler() scaler.fit(x_train) x_train = scaler.transform(x_train) x_valid = scaler.transform(x_valid) # Tensor型に変換 # 学習に入れるときはfloat型になっている必要があるのここで変換してしまう x_train = torch.from_numpy(x_train).float() y_train = torch.from_numpy(y_train).long() x_valid = torch.from_numpy(x_valid).float() y_valid = torch.from_numpy(y_valid).long() print('x_train : ', x_train.shape) print('y_train : ', y_train.shape) print('x_valid : ', x_valid.shape) print('y_valid : ', y_valid.shape) # Datasetの作成 train_dataset = TensorDataset(x_train, y_train) valid_dataset = TensorDataset(x_valid, y_valid) # 動作確認 # indexを指定すればデータを取り出すことができます。 index = 0 print(train_dataset.__getitem__(index)[0].size()) print(train_dataset.__getitem__(index)[1]) # DataLoaderの作成 batch_size = 32 train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) # 動作確認 # こんな感じでバッチ単位で取り出す子ができます。 # イテレータに変換 batch_iterator = iter(train_dataloader) # 1番目の要素を取り出す inputs, labels = next(batch_iterator) print(inputs.size()) print(labels.size()) # Lightningモジュール class Net(pl.LightningModule): # ネットワークで使用する層を記述する def __init__(self): super().__init__() self.fc1 = nn.Linear(4, 50) self.fc2 = nn.Linear(50, 3) # 順伝播処理を記述する def forward(self, x): x = self.fc1(x) x = F.relu(x) x = self.fc2(x) x = F.softmax(x, dim=1) return x # 学習の際に使用する損失関数を記述している def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return {'loss': loss, 'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)} # 各エポック終了時の処理を記述している def training_epoch_end(self, train_step_outputs): y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0) y = torch.cat([val['y'] for val in train_step_outputs], dim=0) epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0) preds = torch.argmax(y_hat, dim=1) acc = accuracy(preds, y) self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True) self.log('train_acc', acc, prog_bar=True, on_epoch=True) print('-------- Current Epoch {} --------'.format(self.current_epoch + 1)) print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc)) def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)} def validation_epoch_end(self, val_step_outputs): # x_hatを一つにまとめる y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0) y = torch.cat([val['y'] for val in val_step_outputs], dim=0) epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0) preds = torch.argmax(y_hat, dim=1) acc = accuracy(preds, y) self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True) self.log('val_acc', acc, prog_bar=True, on_epoch=True) print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc)) # 最適化手法を記述する def configure_optimizers(self): optimizer = optim.SGD(self.parameters(), lr=0.01) return optimizer # Lightningモジュールの作成 net = Net() # 学習処理の設定 # EarlyStoppingの設定 es = pl.callbacks.EarlyStopping(monitor='val_loss') trainer = pl.Trainer( max_epochs=30, callbacks=[es], # GPUを使用する場合 # gpus=2 ) # 学習の実行 trainer.fit( net, # ネットワーク train_dataloader=train_dataloader, # 学習データ val_dataloaders=valid_dataloader, # 検証データ )