1. 組み込みデータセットについて
PyTorchのtorchvisionには、多くの組み込みデータセットが用意されており、これを使うことで学習用のデータを簡単に用意することができます。加えて、これらのデータセットは、DataLoader
に渡すだけで、簡単に利用することが可能です。
torchvisionのデータセットAPIは一貫性があり、基本的に同じAPIを介してどのデータセットも使用することができます。さらに、データセットの画像を変換するためのtransform
パラメータや、ラベルデータを変換するtarget_transform
パラメータを利用することで、データ拡張なども自由に設定することができます。
2. 組み込みデータセットの使い方
torchvisionから組み込みデータセットを呼び出して、Dataset
を作成します。データ拡張などを実施する場合はtransform
パラメータにデータ拡張のクラスを適用します。Dataset
なので__getitem__
でデータを取り出すことができます。
# ライブラリ import torch import torchvision from torchvision import transforms # Datasetの作成 train_dataset = torchvision.datasets.FashionMNIST( # データセットを保存するディレクトリ root='./', # 学習用データセット train=True, # ダウンロードするかどうか download=True, # 画像の変換 transform=transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor() ]) ) # データの取得 # __getitem__に、インデックスを指定することでデータを取り出すことができる image, label = train_dataset.__getitem__(0) print('image.shape : ', image.shape) print('label : ', label)
# Output image.shape : torch.Size([1, 28, 28]) label : 9
データセットを作成したら、Dataloader
にデータセットを適用します。
# Dataloaderの作成 train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=4, shuffle=False ) # データの取り出し batch_iterator = iter(train_dataloader) images, labels = next(batch_iterator) print('images.shape : ', images.shape) print('labels : ', labels)
# Output images.shape : torch.Size([4, 1, 28, 28]) labels : tensor([9, 0, 0, 3])
3. Datasetの種類
組み込みデータセットは、いろいろなタスクに対応したデータが用意されています。Image Classification, Image Detection, Image Segmentationのデータセットを確認してみたいと思います。
他のデータセット一覧や、詳細なパラメータについてはPyTorchの公式ドキュメントを参照してみてください。
1. Classification
画像とその画像がどのクラスに属するかのラベルデータが格納されています。
# Datasetの作成 train_dataset = torchvision.datasets.FashionMNIST( root='./', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor() ]) ) # データの取得 image, label = train_dataset.__getitem__(0) print('image.shape : ', image.shape) print('label : ', label)
# Output image.shape : torch.Size([1, 28, 28]) label : 9
2. Image Detection
物体検出は、ラベルデータとして物体のバウンディングボックスとバウンディングボックス内に存在するクラスラベルが与えられます。Datasetから辞書形式でラベル情報を受け取ることができます。
# VOC Dataset train_dataset = torchvision.datasets.VOCDetection( # データセットを保存するディレクトリ root='./', # データセットの種類(2007, 2012) year='2012', # 学習用データセット image_set='train', # ダウンロードするかどうか download=True, # 画像の変換 transform=transforms.Compose([ transforms.ToTensor() ]) ) # データの取得 image, label = train_dataset.__getitem__(0) print('image.shape : ', image.shape) print('labels : ', labels)
# Output image.shape : torch.Size([3, 442, 500]) labels : {'annotation': {'folder': 'VOC2012', 'filename': '2008_000008.jpg', 'source': {'database': 'The VOC2008 Database', 'annotation': 'PASCAL VOC2008', 'image': 'flickr'}, 'size': {'width': '500', 'height': '442', 'depth': '3'}, 'segmented': '0', 'object': [{'name': 'horse', 'pose': 'Left', 'truncated': '0', 'occluded': '1', 'bndbox': {'xmin': '53', 'ymin': '87', 'xmax': '471', 'ymax': '420'}, 'difficult': '0'}, {'name': 'person', 'pose': 'Unspecified', 'truncated': '1', 'occluded': '0', 'bndbox': {'xmin': '158', 'ymin': '44', 'xmax': '289', 'ymax': '167'}, 'difficult': '0'}]}}
3. Image Segmentation
セグメンテーションは、ラベルデータとして各画素値に対してどのクラスに属するかのインデックスが割り振られています。注意点として、RGB形式で読み込んでしまうと、アノテーションデータとして使えなくなってしまいます。ラベルデータはパレットモードで読み込む必要があります。
# VOC Dataset train_dataset = torchvision.datasets.VOCSegmentation( # データセットを保存するディレクトリ root='./', # データセットの種類(2007, 2012) year='2012', # 学習用データセット image_set='train', # ダウンロードするかどうか download=True, ) # データの取得 image, label = train_dataset.__getitem__(0) plt.imshow(image) plt.show() plt.imshow(label) plt.show()