【PyTorch入門】TensorのNumpy変換

1. 概要

PyTorchにはTensorというものが存在します。

TensorはPyTorchの基本となるデータ構造で、多次元配列を扱います。

PyTorchでTensorをモデルの入力・出力・モデルのパラメータなどで使用します。

Tensorの操作方法はNumpyの操作方法と似ており、Numpyと同じような感覚で操作できます。

PyTorchのTensorからNumpyのndarrayへの変換と、NumpyのndarrayからPyTorchのTensorへの変換方法を紹介します。


2. 「torch.Tensor」から「numpy.ndarray」への変換

PyTorchのTensor型の作成はtorch.tensorを使います。

ndarrayへの変換にはnumpy()を呼び出せば、変換することができます。

# Tensorを用意
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)

print('変換前のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)

# numpy.arrayへ変換
x = x.numpy()

print('変換後のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)
# Output
変換前のデータ型と中身の確認
データ型 :  <class 'torch.Tensor'>
データ : 
 tensor([[1, 2, 3],
        [4, 5, 6]], dtype=torch.int32)

変換後のデータ型と中身の確認
データ型 :  <class 'numpy.ndarray'>  # ndarray型に変換できています
データ : 
 [[1 2 3]
 [4 5 6]]



3. 「numpy.ndarray」から「torch.Tensor」への変換

ndarrayからTensor型への変換はfrom_numpy()を呼び出せば、変換することができます。

# arrayの用意
x = np.array([[1, 2, 3], [4, 5, 6]])

print('変換前のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)


# torch.Tensorへ変換
x = torch.from_numpy(x)

print('変換後のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)
# Output
変換前のデータ型と中身の確認
データ型 :  <class 'numpy.ndarray'>
データ : 
 [[1 2 3]
 [4 5 6]]

変換後のデータ型と中身の確認
データ型 :  <class 'torch.Tensor'>  # Tensor型に変換ができています
データ : 
 tensor([[1, 2, 3],
        [4, 5, 6]], dtype=torch.int32)




4. 全体コード

最後に全体のコードをのせておきます。

# ライブラリの読み込み
import numpy as np
import torch


# torch.Tensorからnumpy.ndarrayへの変換

# Tensorを用意
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)

print('変換前のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)

# numpy.arrayへ変換
x = x.numpy()

print('変換後のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)


# numpy.ndarrayからtorch.Tensorへの変換

# arrayの用意
x = np.array([[1, 2, 3], [4, 5, 6]])

print('変換前のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)


# torch.Tensorへ変換
x = torch.from_numpy(x)

print('変換後のデータ型と中身の確認')
print('データ型 : ', type(x))
print('データ : \n', x)