PyTorchで実装するPSPNet DataLoader実装偏 ④

PSPNetを実装していく前に、データ読み込みや前処理を自動で適用するためのDatasetクラスを作成し、ミニバッチ学習の時にデータを取りだしやすくするためのDataLoaderを実装していきます。

ディレクトリ構成は以下のような構成としています。

├─ data
|   └─ VOCdevkit
|      └─ VOC2012
|           ├─ Annotations
|           ├─ ImageSets
|           ├─ JPEGImages
|           ├─ SegmentationClass
|           └─ SegmentationObject
└─ util
    └─ dataloader.py  # ここに今回の処理を記述していく

次に今回使用するライブラリです。

import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.utils.data as data
from torchvision import transforms

最後に全体のコードも載せているので、実際に動かしながら確認していただければと思います。




1. ファイルパスのリストを作成

PASCAL VOC 2012では、学習データと検証データに分割したリストが公開されています。

今回は、このファイルを基準に学習データと検証データに分割していきます。

ファイルの場所はVOCdevkit/VOC2012/ImageSets/Segmentation配下にあります。

def make_datapath_list(rootpath):
    """
    学習用と検証用の画像データとアノテーションデータのファイルパスを格納したリストを取得する
    
    Args:
        rootpath(str): データフォルダへのパス
    
    Returns:
        train_img_list(list): 学習用の画像データへのファイルパス
        train_anno_list(list): 学習用のアノテーションデータへのファイルパス
        valid_img_list(list): 検証用の画像データへのファイルパス
        valid_anno_list(list): 検証用のアノテーションデータへのファイルパス
    """
    
    # 学習用画像の一覧を取得
    with open(os.path.join(rootpath, 'ImageSets/Segmentation/train.txt'), mode='r') as f:
        train_list = f.readlines()
        train_list = [val.replace('\n', '') for val in train_list]

    # 検証用画像の一覧を取得
    with open(os.path.join(rootpath, 'ImageSets/Segmentation/val.txt'), mode='r') as f:
        valid_list = f.readlines()
        valid_list = [val.replace('\n', '') for val in valid_list]
    
    # 学習用データのリストを作成
    train_img_list = [os.path.join(rootpath, f'JPEGImages/{val}.jpg') for val in train_list]
    train_anno_list = [os.path.join(rootpath, f'SegmentationClass/{val}.png') for val in train_list]
    
    # 検証用データのリストを作成
    valid_img_list = [os.path.join(rootpath, f'JPEGImages/{val}.jpg') for val in valid_list]
    valid_anno_list = [os.path.join(rootpath, f'SegmentationClass/{val}.png') for val in valid_list]
    
    return train_img_list, train_anno_list, valid_img_list, valid_anno_list

実際に処理を実行すると、学習データと検証データのファイルパスが格納されたリストが返されます。

rootpath = '../data/VOCdevkit/VOC2012/'
train_img_list, train_anno_list, valid_img_list, valid_anno_list = make_datapath_list(rootpath)

print(train_img_list[0])
print(train_anno_list[0])
# Output
../data/VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg
../data/VOCdevkit/VOC2012/SegmentationClass/2007_000032.png




2. 前処理クラスの作成

画像とアノテーションデータの前処理を行うDataTrasnformクラスを作成します。

学習時と推論時で異なる動作をするように設定します。

セグメンテーションタスクの前処理として、注意点があります。

画像分類の時とはことなり、画像データをリサイズなどして変換した場合、それに対応したアノテーションデータも同時に変換する必要があります。

そうしないと、画像データとアノテーションデータとで、ずれが生じてしまい正しいアノテーションデータとならないためです。

画像データと同時にアノテーションデータを変換する処理はPyTorchに用意されていないため、自分で作成する必要があります。

セグメンテーションタスク用の前処理を順番に適用していくComposeクラスと、データ拡張のためのクラスも同時に作成していきます。

物体検知用にはなりますが、実装としては次のコードを参考にしています。

https://github.com/amdegroot/ssd.pytorch


2.1. Composeクラスの作成

前処理用のクラスを順番に適用していくComposeクラスというのがあります。

PyTorchで実装されているComposeクラスは次のようになっています。

class Compose:
    """
    ~~~
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

__call__部分を見てわかるように、画像データのみを変換しています。

セグメンテーションでは、前処理をする際には画像データとアノテーションデータも変換する必要があるので、次のようにComposeクラスを変更します。

class Compose(object):
    """
    指定した前処理を順次適用していくクラス
    
    Args:
        transforms (List[Transform]): 変換処理を格納したリスト
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, anno):
        for t in self.transforms:
            img, anno = t(img, anno)
        return img, anno

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

変更した点としては、__call__部分に画像データとアノテーションデータが渡せるように変更しました。


2.2. データ整形用クラスの作成

データ整形用のクラスもComposeクラスと同様に、画像データとアノテーションデータに対して変換できるようなクラスを作成します。

今回実装したデータ整形用クラスは、次の二つになります。

  • 画像データをリサイズさせるクラス
  • PIL画像をTensorに変換させ、正規化させるクラス

実装は次の通りです。

class Resize(object):
    """
    指定したサイズにリサイズする
    """

    def __init__(self, size=475):
        self.size = size

    def __call__(self, image, anno_img):
        image = image.resize((self.size, self.size), Image.BICUBIC)
        anno_img = anno_img.resize((self.size, self.size), Image.NEAREST)
        return image, anno_img
class Normalize(object):
    """
    画像データを0~1に正規化する
    """
    def __init__(self):
        pass

    def __call__(self, image, anno_img):
        # 画像データをPILからTensorに変換
        image = transforms.functional.to_tensor(image)
        
        # 0~1に正規化
        image = image / 255
        
        # アノテーション画像をNumpyに変換する
        anno_img = np.array(anno_img)
        
        # 境界値である255を0(backgroud)に変換する
        index = np.where(anno_img == 255)
        anno_img[index] = 0
        
        # アノテーション画像をTensorに変換する
        anno_img = torch.from_numpy(anno_img)

        return image, anno_img


2.3. データ拡張用クラスの作成

データ拡張用のクラスもComposeクラスと同様に、画像データとアノテーションデータに対して変換できるようなクラスを作成します。

今回実装したデータ拡張用クラスは、次の二つになります。

  • ランダムに上下を反転させるクラス
  • ランダムに左右を反転させるクラス

実装は次の通りです。

class RandomVerticalFlip(object):
    """
    ランダムに画像の上下を反転させる
    """

    def __init__(self, p=0.5):
        self.p = p
        
    def __call__(self, img, anno_img):
        
        if torch.rand(1) < self.p:
            return transforms.functional.vflip(img), transforms.functional.vflip(anno_img)
        
        return img, anno_img
class RandomHorizontalFlip(object):
    """
    ランダムに画像の左右を反転させる
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, anno_img):
        
        if torch.rand(1) < self.p:
            return transforms.functional.hflip(img), transforms.functional.hflip(anno_img)
        
        return img, anno_img


2.4. DataTransformクラスの作成

上の3つで準備した前処理用のクラスを使用して、画像データとアノテーションデータの前処理クラスを作成します。

訓練時はデータ拡張を実施し、推論時はデータ拡張は実施しないといった動作が異なります。

class DataTransform(object):
    """
    画像データとアノテーションデータの前処理クラス。
    訓練時はデータ拡張を実施し、推論時はデータ拡張を実施しない。
    
    Attributes:
        input_size(int): リサイズ先の画像の大きさ
    """
    def __init__(self, input_size):
        self.data_transform = {
            'train': Compose([
                Resize(input_size),
                Normalize(),
                RandomHorizontalFlip(),
                RandomVerticalFlip()
            ]),
            'valid': Compose([
                Resize(input_size),
                Normalize()
            ])
        }
    
    def __call__(self, phase, img, anno_img):
        return self.data_transform[phase](img, anno_img)

稼働確認をすると、画像データとアノテーションデータが前処理された結果が返されることがわかります。

# ファイルパスのリストを取得する
rootpath = '../data/VOCdevkit/VOC2012/'
train_img_list, train_anno_list, valid_img_list, valid_anno_list = make_datapath_list(rootpath)

train_img_path = train_img_list[0]
train_anno_img_path = train_anno_list[0]

# 稼働確認として画像ファイルを、サンプルとして読み込み
img = Image.open(train_img_path)
anno_img = Image.open(train_anno_img_path)

# 前処理の実行
transformer = DataTransform(input_size=475)
img, anno_img = transformer('train', img, anno_img)

print(img.shape)
print(anno_img.shape)
# Output
torch.Size([3, 475, 475])
torch.Size([475, 475])

アノテーションデータは、可視化する際にインデックスカラーで可視化していないので少し見にくいですが、アノテーションの対応関係が崩れることなく前処理ができています。

また、学習時の設定で前処理を何回か実行すると、上下または左右が逆になる画像がランダムに出力されることもわかります。

# 可視化のためにNumpyに変換
img_array = np.array(img) * 255
img_array = img_array.transpose(1, 2, 0)
anno_img_array = np.array(anno_img)

# 可視化
fig = plt.figure(figsize=(15, 6))

ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)

ax1.imshow(img_array)
ax2.imshow(anno_img_array)

plt.show()




3. Datasetクラスの作成

PyTorchのDatasetクラスを継承させたクラスを作ります。

このクラスは後で説明するDataLoaderクラスに組み込むときに役立ちます。

処理を記述する箇所は「__len__」「__getitem__」の二つです。

「__len__」にはDatasetに含まれるデータ数を返す処理を記述します。

「__getitem__」にはindex番号を引数にとり、学習データとラベルデータを返す処理を記述します。

class VOCDataset(data.Dataset):
    """
    VOC2012のDatasetを作成するクラス
    
    Attributes:
        img_list(list): 画像データのファイルパスを格納したリスト
        anno_img_list(list): アノテーションデータのファイルパスを格納したリスト
        phase(str): 'train' or 'valid'
        transform(object): 前処理クラスのインスタンス
    """
    
    def __init__(self, img_list, anno_img_list, phase, transform):
        self.img_list = img_list
        self.anno_img_list = anno_img_list
        self.phase = phase
        self.transform = transform
        
    def __len__(self):
        """
        画像の枚数を返す
        """
        return len(self.img_list)
    
    def __getitem__(self, index):
        """
        前処理後のTensor形式の画像データとアノテーションデータを返す
        """
        img, anno_img = self.pull_item(index)
        return img, anno_img

    def pull_item(self, index):
        """
        画像データとアノテーションデータを読み込み、前処理を実施する
        """
        # 読み込み
        img_file_path = self.img_list[index]
        anno_img_file_path = self.anno_img_list[index]
        
        img = Image.open(img_file_path)
        anno_img = Image.open(anno_img_file_path)
        
        # 前処理の実施
        img, anno_img = self.transform(self.phase, img, anno_img)
        
        return img, anno_img

実装したクラスを使用してデータセットを作成してみます。

実際にデータセットが作成できていることがわかります。

# ファイルパスのリストを取得する
rootpath = '../data/VOCdevkit/VOC2012/'
train_img_list, train_anno_list, valid_img_list, valid_anno_list = make_datapath_list(rootpath)

# データセットの作成
train_dataset = VOCDataset(
    img_list=train_img_list,
    anno_img_list=train_anno_list,
    phase='train',
    transform=DataTransform(input_size=475)
)

valid_dataset = VOCDataset(
    img_list=valid_img_list,
    anno_img_list=valid_anno_list,
    phase='valid',
    transform=DataTransform(input_size=475)
)

# データセットの取り出し
print(valid_dataset.__getitem__(0)[0].shape)
print(valid_dataset.__getitem__(0)[1].shape)
# Output
torch.Size([3, 475, 475])
torch.Size([475, 475])




4. DataLoaderの作成

バッチ処理を適用するためにDataLoaderを作成して、データセットをバッチ単位で取り出せるようにします。

PyTorchで実装されているDataLoaderに上で作成したDatasetを渡してやれば、DataLoaderを作成することができます。

# DataLoaderの作成
batch_size = 8

train_dataloader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

valid_dataloader = data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False
)

# 動作確認
batch_iterator = iter(valid_dataloader)
imgs, anno_imgs = next(batch_iterator)

print(imgs.shape)
print(anno_imgs.shape)
# Output
torch.Size([8, 3, 475, 475])
torch.Size([8, 475, 475])

画像データとアノテーションデータで、それぞれ8個のデータがDataLoaderから取り出せることがわかると思います。




5. まとめ

以上でDataLoaderの実装が完了になります。

DataLoaderを使用すれば、ミニバッチ学習の時など学習の際にデータを取り出しやすくなり便利になります。

最後に全体のコードを載せておきます。


6. 参考


7. 全体コード

# util.dataloader.py
### ライブラリ
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.utils.data as data
from torchvision import transforms


### クラス定義
class Compose(object):
    """
    指定した前処理を順次適用していくクラス
    
    Args:
        transforms (List[Transform]): 変換処理を格納したリスト
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, anno):
        for t in self.transforms:
            img, anno = t(img, anno)
        return img, anno

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

    
class Resize(object):
    """
    指定したサイズにリサイズする
    """

    def __init__(self, size=475):
        self.size = size

    def __call__(self, image, anno_img):
        image = image.resize((self.size, self.size), Image.BICUBIC)
        anno_img = anno_img.resize((self.size, self.size), Image.NEAREST)
        return image, anno_img

    
class Normalize(object):
    """
    画像データを0~1に正規化する
    """
    def __init__(self):
        pass

    def __call__(self, image, anno_img):
        # 画像データをPILからTensorに変換
        image = transforms.functional.to_tensor(image)
        
        # 0~1に正規化
        image = image / 255.
        
        # アノテーション画像をNumpyに変換する
        anno_img = np.array(anno_img)
        
        # 境界値である255を0(backgroud)に変換する
        index = np.where(anno_img == 255)
        anno_img[index] = 0
        
        # アノテーション画像をTensorに変換する
        anno_img = torch.from_numpy(anno_img)

        return image, anno_img

    
class RandomVerticalFlip(object):
    """
    ランダムに画像の上下を反転させる
    """

    def __init__(self, p=0.5):
        self.p = p
        
    def __call__(self, img, anno_img):
        
        if torch.rand(1) < self.p:
            return transforms.functional.vflip(img), transforms.functional.vflip(anno_img)
        
        return img, anno_img
    
    
class RandomHorizontalFlip(object):
    """
    ランダムに画像の左右を反転させる
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, anno_img):
        
        if torch.rand(1) < self.p:
            return transforms.functional.hflip(img), transforms.functional.hflip(anno_img)
        
        return img, anno_img
    
    
class DataTransform(object):
    """
    画像データとアノテーションデータの前処理クラス。
    訓練時はデータ拡張を実施し、推論時はデータ拡張を実施しない。
    
    Attributes:
        input_size(int): リサイズ先の画像の大きさ
    """
    def __init__(self, input_size):
        self.data_transform = {
            'train': Compose([
                Resize(input_size),
                Normalize(),
                RandomHorizontalFlip(),
                RandomVerticalFlip()
            ]),
            'valid': Compose([
                Resize(input_size),
                Normalize()
            ])
        }
    
    def __call__(self, phase, img, anno_img):
        return self.data_transform[phase](img, anno_img)
    
    
class VOCDataset(data.Dataset):
    """
    VOC2012のDatasetを作成するクラス
    
    Attributes:
        img_list(list): 画像データのファイルパスを格納したリスト
        anno_img_list(list): アノテーションデータのファイルパスを格納したリスト
        phase(str): 'train' or 'valid'
        transform(object): 前処理クラスのインスタンス
    """
    
    def __init__(self, img_list, anno_img_list, phase, transform):
        self.img_list = img_list
        self.anno_img_list = anno_img_list
        self.phase = phase
        self.transform = transform
        
    def __len__(self):
        """
        画像の枚数を返す
        """
        return len(self.img_list)
    
    def __getitem__(self, index):
        """
        前処理後のTensor形式の画像データとアノテーションデータを返す
        """
        img, anno_img = self.pull_item(index)
        return img, anno_img

    def pull_item(self, index):
        """
        画像データとアノテーションデータを読み込み、前処理を実施する
        """
        # 読み込み
        img_file_path = self.img_list[index]
        anno_img_file_path = self.anno_img_list[index]
        
        img = Image.open(img_file_path)
        anno_img = Image.open(anno_img_file_path)
        
        # 前処理の実施
        img, anno_img = self.transform(self.phase, img, anno_img)
        
        return img, anno_img
    

### 関数定義
def make_datapath_list(rootpath):
    """
    学習用と検証用の画像データとアノテーションデータのファイルパスを格納したリストを取得する
    
    Args:
        rootpath(str): データフォルダへのパス
    
    Returns:
        train_img_list(list): 学習用の画像データへのファイルパス
        train_anno_list(list): 学習用のアノテーションデータへのファイルパス
        valid_img_list(list): 検証用の画像データへのファイルパス
        valid_anno_list(list): 検証用のアノテーションデータへのファイルパス
    """
    
    # 学習用画像の一覧を取得
    with open(os.path.join(rootpath, 'ImageSets/Segmentation/train.txt'), mode='r') as f:
        train_list = f.readlines()
        train_list = [val.replace('\n', '') for val in train_list]

    # 検証用画像の一覧を取得
    with open(os.path.join(rootpath, 'ImageSets/Segmentation/val.txt'), mode='r') as f:
        valid_list = f.readlines()
        valid_list = [val.replace('\n', '') for val in valid_list]
    
    # 学習用データのリストを作成
    train_img_list = [os.path.join(rootpath, f'JPEGImages/{val}.jpg') for val in train_list]
    train_anno_list = [os.path.join(rootpath, f'SegmentationClass/{val}.png') for val in train_list]
    
    # 検証用データのリストを作成
    valid_img_list = [os.path.join(rootpath, f'JPEGImages/{val}.jpg') for val in valid_list]
    valid_anno_list = [os.path.join(rootpath, f'SegmentationClass/{val}.png') for val in valid_list]
    
    return train_img_list, train_anno_list, valid_img_list, valid_anno_list


# 動作確認
if __name__ == '__main__':
    
    ### 1. ファイルパスのリストを作成

    # 学習用と検証用の画像データとアノテーションデータのファイルパスを格納したリストを取得する
    rootpath = '../data/VOCdevkit/VOC2012/'

    train_img_list, train_anno_list, valid_img_list, valid_anno_list = make_datapath_list(rootpath)

    print(train_img_list[0])
    print(train_anno_list[0])


    ### 2. 前処理クラスの作成

    # 稼働確認として画像ファイルを、サンプルとして読み込み
    train_img_path = train_img_list[0]
    train_anno_img_path = train_anno_list[0]

    img = Image.open(train_img_path)
    anno_img = Image.open(train_anno_img_path)

    # 前処理の実行
    transformer = DataTransform(input_size=475)
    img, anno_img = transformer('train', img, anno_img)

    print(img.shape)
    print(anno_img.shape)

    img_array = np.array(img) * 255
    img_array = img_array.transpose(1, 2, 0)

    anno_img_array = np.array(anno_img)

    fig = plt.figure(figsize=(15, 6))

    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)

    ax1.imshow(img_array)
    ax2.imshow(anno_img_array)

    plt.show()


    # 3. Datasetクラスの作成

    # 学習データのDataset
    train_dataset = VOCDataset(
        img_list=train_img_list,
        anno_img_list=train_anno_list,
        phase='train',
        transform=DataTransform(input_size=475)
    )

    # 検証データのDataset
    valid_dataset = VOCDataset(
        img_list=valid_img_list,
        anno_img_list=valid_anno_list,
        phase='valid',
        transform=DataTransform(input_size=475)
    )

    # データセットの取り出し
    print(valid_dataset.__getitem__(0)[0].shape)
    print(valid_dataset.__getitem__(0)[1].shape)


    ### 4. DataLoaderの作成

    batch_size = 8

    # 学習データのDataLoader
    train_dataloader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    # 検証データのDataLoader
    valid_dataloader = data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False
    )

    # 動作確認
    batch_iterator = iter(valid_dataloader)
    imgs, anno_imgs = next(batch_iterator)

    print(imgs.shape)
    print(anno_imgs.shape)