更新:2024/10/03
PyTorchの基本要素(Tensor、Autograd、nn、Dataset)のまとめについて


はるか
最近、PyTorchに興味が出てきた。

ふゅか
いいじゃん!まずはテンソルから始めてみようよ。テンソルはPyTorchの基本だしね!
目次
1. Tensor

はるか
テンソルって多次元配列のこと?

ふゅか
そうそう!GPUを使って高速計算もできるんだ!
テンソルは、PyTorchにおける基本的なデータ構造であり、多次元の数値配列を表現します。NumPyのndarray
に似ていますが、PyTorchのテンソルはGPUを利用して高速な計算が可能であり、深層学習に特化した機能があります。
1.1. 多次元配列
テンソルは0次元(スカラー)から1次元(ベクトル)、2次元(行列)、さらに高次元のデータまで扱えます。
1.2. CPUとGPUのデバイス対応
テンソルはCPUとGPUの両方で操作可能です。.to(device)
メソッドでデバイス間の移動が簡単に行えます。
1.3. データ型
浮動小数点数、整数、ブール値など、多様なデータ型をサポートしています。
1.4. 例
import torch
# CPU上のテンソル
x = torch.Tensor([[1, 2], [3, 4]])
# GPU上のテンソル(CUDAが利用可能な場合)
if torch.cuda.is_available():
x = x.to('cuda')
2. Autograd(自動微分)

はるか
Autogradっていうのも気になる。

ふゅか
それは自動微分の仕組みだよ!勾配を自動で計算してくれるから、とっても便利なんだ。
2.1. 勾配計算
requires_grad=True
を設定すると、そのテンソルに対する操作が追跡され、backward()
メソッドで勾配を計算できます。
2.2. 例
import torch
# 勾配計算を有効にしたテンソル
x = torch.tensor(2.0, requires_grad=True)
y = x ** 5
y.backward()
print(x.grad) # dy/dx = 5x^4 ⇒ 5 * 2.0^4 = 80.
3. nnモジュール(ニューラルネットワーク)
torch.nn
は、ニューラルネットワークを簡単に構築するためのモジュールです。
3.1. レイヤー
nn.Linear
, nn.Conv2d
, nn.LSTM
などの層を提供します。
3.2. 損失関数
nn.CrossEntropyLoss
, nn.MSELoss
など、多様な損失関数が利用可能です。
3.3. モデル構築
nn.Module
を継承して、自身のモデルを定義できます。
3.4. 例
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)

はるか
自分でモデルを定義できるのは面白い。
4. Optimizer(最適化アルゴリズム)

はるか
Optimizerについても教えて。

ふゅか
モデルのパラメータを更新するアルゴリズムだよ!
torch.optim
は、モデルのパラメータを更新するための最適化アルゴリズムを提供します。
4.1. アルゴリズム
SGD, Adam, RMSpropなどが利用可能です。学習率などのハイパーパラメータも設定可能です。
4.2. 例
import torch.optim as optim
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. DataLoaderとDataset
データを効率的に読み込むためのツールが提供されています。
5.1. Dataset
データセットを扱うためのクラスで、__len__
と__getitem__
を実装します。
5.2. DataLoader
バッチ処理、シャッフル、マルチプロセッシングによるデータ読み込みをサポートします。
5.3. 例
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self):
# データの初期化
pass
def __len__(self):
# データセットの長さを返す
return 100
def __getitem__(self, idx):
# インデックスidxのデータを返す
return data
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
PR