PyTorchの組み込みデータセット




1. 組み込みデータセットについて

PyTorchのtorchvisionには、多くの組み込みデータセットが用意されており、これを使うことで学習用のデータを簡単に用意することができます。加えて、これらのデータセットは、DataLoaderに渡すだけで、簡単に利用することが可能です。


torchvisionのデータセットAPIは一貫性があり、基本的に同じAPIを介してどのデータセットも使用することができます。さらに、データセットの画像を変換するためのtransformパラメータや、ラベルデータを変換するtarget_transformパラメータを利用することで、データ拡張なども自由に設定することができます。



2. 組み込みデータセットの使い方

torchvisionから組み込みデータセットを呼び出して、Datasetを作成します。データ拡張などを実施する場合はtransformパラメータにデータ拡張のクラスを適用します。Datasetなので__getitem__でデータを取り出すことができます。

# ライブラリ
import torch
import torchvision
from torchvision import transforms

# Datasetの作成
train_dataset = torchvision.datasets.FashionMNIST(
    # データセットを保存するディレクトリ
    root='./',
    # 学習用データセット
    train=True,
    # ダウンロードするかどうか
    download=True,
    # 画像の変換
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor()
    ])
)

# データの取得
# __getitem__に、インデックスを指定することでデータを取り出すことができる
image, label = train_dataset.__getitem__(0)

print('image.shape : ', image.shape)
print('label : ', label)
# Output
image.shape :  torch.Size([1, 28, 28])
label :  9


データセットを作成したら、Dataloaderにデータセットを適用します。

# Dataloaderの作成
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=False
)

# データの取り出し
batch_iterator = iter(train_dataloader)
images, labels = next(batch_iterator)

print('images.shape : ', images.shape)
print('labels : ', labels)
# Output
images.shape :  torch.Size([4, 1, 28, 28])
labels :  tensor([9, 0, 0, 3])



3. Datasetの種類

組み込みデータセットは、いろいろなタスクに対応したデータが用意されています。Image Classification, Image Detection, Image Segmentationのデータセットを確認してみたいと思います。


他のデータセット一覧や、詳細なパラメータについてはPyTorchの公式ドキュメントを参照してみてください。

pytorch.org


1. Classification

画像とその画像がどのクラスに属するかのラベルデータが格納されています。

# Datasetの作成
train_dataset = torchvision.datasets.FashionMNIST(
    root='./',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

# データの取得
image, label = train_dataset.__getitem__(0)

print('image.shape : ', image.shape)
print('label : ', label)
# Output
image.shape :  torch.Size([1, 28, 28])
label :  9



2. Image Detection

物体検出は、ラベルデータとして物体のバウンディングボックスとバウンディングボックス内に存在するクラスラベルが与えられます。Datasetから辞書形式でラベル情報を受け取ることができます。

# VOC Dataset
train_dataset = torchvision.datasets.VOCDetection(
    # データセットを保存するディレクトリ
    root='./',
    # データセットの種類(2007, 2012)
    year='2012',
    # 学習用データセット
    image_set='train',
    # ダウンロードするかどうか
    download=True,
    # 画像の変換
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

# データの取得
image, label = train_dataset.__getitem__(0)

print('image.shape : ', image.shape)
print('labels : ', labels)
# Output
image.shape :  torch.Size([3, 442, 500])
labels :  {'annotation': {'folder': 'VOC2012', 'filename': '2008_000008.jpg', 'source': {'database': 'The VOC2008 Database', 'annotation': 'PASCAL VOC2008', 'image': 'flickr'}, 'size': {'width': '500', 'height': '442', 'depth': '3'}, 'segmented': '0', 'object': [{'name': 'horse', 'pose': 'Left', 'truncated': '0', 'occluded': '1', 'bndbox': {'xmin': '53', 'ymin': '87', 'xmax': '471', 'ymax': '420'}, 'difficult': '0'}, {'name': 'person', 'pose': 'Unspecified', 'truncated': '1', 'occluded': '0', 'bndbox': {'xmin': '158', 'ymin': '44', 'xmax': '289', 'ymax': '167'}, 'difficult': '0'}]}}



3. Image Segmentation

セグメンテーションは、ラベルデータとして各画素値に対してどのクラスに属するかのインデックスが割り振られています。注意点として、RGB形式で読み込んでしまうと、アノテーションデータとして使えなくなってしまいます。ラベルデータはパレットモードで読み込む必要があります。

# VOC Dataset
train_dataset = torchvision.datasets.VOCSegmentation(
    # データセットを保存するディレクトリ
    root='./',
    # データセットの種類(2007, 2012)
    year='2012',
    # 学習用データセット
    image_set='train',
    # ダウンロードするかどうか
    download=True,
)

# データの取得
image, label = train_dataset.__getitem__(0)

plt.imshow(image)
plt.show()

plt.imshow(label)
plt.show()