アクセスありがとうございます!次は「とあるエンジニアのエソラゴト」で検索して頂けると嬉しいです!

【PyTorch】CNNでデータ拡張をしながら、CIFAR10を分類するサンプルコード【初心者】

はじめに

冬休みということでDeep Learningを勉強しています。

今日はCNNを使って、CIFAR10の画像データを分類してみようと思います。

また精度を上げるために、データ拡張も挑戦してみようと思っています。

前提

PyTorchを使います。

開発環境は「Google Colaboratory」を使っていきます。

AWSのSageMakerを使ってGPUを利用する場合は、以下をご参照下さい。

関連記事

私 冬休みということで、Deep Learningを集中的に勉強しています。 普段は「Google Colaboratory」を使っていますが、無料使用だとGPU使用に制限があるため、すぐにGPUが使えなくなってしまいました。[…]

テストデータはCIFAR10という画像データを使います。

CIFAR-10は、以下のdog(犬)やship(船)といった画像データを合計で10個集めて、正解ラベルをつけたデータセットです。

  1. plane(飛行機)
  2. car(車)
  3. bird(鳥)
  4. cat(猫)
  5. deer(鹿)
  6. dog(犬)
  7. frog(カエル)
  8. horse(馬)
  9. ship(船)
  10. truck(トラック)

MNISTという手書き文字データよりも複雑なものが画像データとして写っているので、学習難易度は上がる傾向にあります。

コードの全体はGitHubにアップロードしています。

サンプルコードの全体はこちら。

CIFAR10のデータ準備の実装サンプルコード

ライブラリのインポート、GPU使用

まずは必要なライブラリを、インポートします。

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

そしてGPUが使えたら使いたいので、「GPUが使えるのであればGPUを使うように、そうでなければCPUを使うように」という設定をします。

device = "cuda" if torch.cuda.is_available() else "cpu"

データの変換

そして、以下のコードでデータを変換する処理をコーディングします。

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))                                
])

transforms.ToTensor()はTensor型に変換をしています。

transforms.Normalize((0.5,), (0.5,))は第1引数が平均、第2引数が標準偏差を表し、それに合わせて正規化するように変換をします。

データ拡張

データ拡張とは手元で使える画像データが少ない場合に、プログラムからデータを傾けたり、色彩を変化させてデータを擬似的に増やす手法です。

Deep Learningの分野ではデータ量は多ければ多いほど、良い精度が出る傾向にあるので、データ量を増やすことは非常に大切です。

ただ、集めることが大変な時もあると思うので、是非覚えておきたい手法です。

やり方は簡単で、上記に実装したtransforms.Composeにコードを追加するだけです。

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    # チャネル毎に画像全体の平均値と標準偏差が0.5になるように標準化
    transforms.Normalize((0.5,), (0.5,))                                
])

transforms.RandomHorizontalFlip()は、「ランダムに画像データを水平反転」をします。

transforms.ColorJitter()は、「ランダムに画像の明るさ、コントラスト、彩度、色相を変える」ことが出来ます。

transforms.RandomRotation()は、「ランダムに回転をさせる」ことが出来ます。

CIFAR10のダウンロード

そして、CIFAR10のデータを用意します。

CIFAR10のサンプルデータはPyTorchのライブラリを使って、インターネット経由でダウンロードしてきます。

今回は実際に訓練を行う訓練データと、検証用に使用するテストデータの2つに分けます。

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)

そして、以下のようにデータローダーにそれぞれのデータを代入します。

# 訓練用データ
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# テスト用データ
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

データの中身を見てみます。

data_iter = iter(train_dataloader)
imgs, labels = data_iter.next()

ラベルを出すると、以下のように数値の正解ラベルが入っています。

labels

# 以下は出力値
tensor([5, 9, 8, 4, 3, 4, 4, 7, 2, 6, 8, 0, 6, 0, 7, 2, 3, 6, 4, 8, 0, 7, 8, 6,
        5, 3, 8, 5, 9, 0, 4, 1])

では、写真を出力してみます。

img = imgs[0]
img_permute = img.permute(1, 2, 0)
img_permute = 0.5 * img_permute + 0.5
img_permute = np.clip(img_permute, 0, 1)
plt.imshow(img_permute)

鹿の画像が出ました!

CNNモデルの実装サンプルコード

では、ここからCNNのモデル実装をしていきます。

モデルの全体像は以下のようになります。

class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(in_features=4 * 4 * 128, out_features=num_classes)
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

nn.Conv2dで畳み込みフィルターを行います。

活性化関数にはReLUを使用します。

nn.MaxPool2dでマックスプーリングを行います。

model = CNN(10)
model.to(device)

損失関数はクロスエントロピーを選択し、最適化関数はAdamを使用します。

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

CNNモデルの訓練処理の実装サンプルコード

それでは、学習をさせます。

最初に訓練データでモデルを学習させて、テストデータで検証をします。

num_epochs = 15
losses = []
accs = []
val_losses = []
val_accs = []
for epoch in range(num_epochs):
    running_loss = 0.0
    running_acc = 0.0
    for imgs, labels in train_dataloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        running_loss += loss.item()
        pred = torch.argmax(output, dim=1)
        running_acc += torch.mean(pred.eq(labels).float())
        optimizer.step()
    running_loss /= len(train_dataloader)
    running_acc /= len(train_dataloader)
    losses.append(running_loss)
    accs.append(running_acc)
    #
    # test loop
    #
    val_running_loss = 0.0
    val_running_acc = 0.0
    for val_imgs, val_labels in test_dataloader:
        val_imgs = val_imgs.to(device)
        val_labels = val_labels.to(device)
        val_output = model(val_imgs)
        val_loss = criterion(val_output, val_labels)
        val_running_loss += val_loss.item()
        val_pred = torch.argmax(val_output, dim=1)
        val_running_acc += torch.mean(val_pred.eq(val_labels).float())
    val_running_loss /= len(test_dataloader)
    val_running_acc /= len(test_dataloader)
    val_losses.append(val_running_loss)
    val_accs.append(val_running_acc)
    print("epoch: {}, loss: {}, acc: {}, \
     val loss: {}, val acc: {}".format(epoch, running_loss, running_acc, val_running_loss, val_running_acc))

すると、以下のような結果になります。

epoch: 0, loss: 1.433305179813468, acc: 0.47696736454963684,      val loss: 1.0972439880950002, val acc: 0.6093250513076782
epoch: 1, loss: 1.0305320225460592, acc: 0.6342170238494873,      val loss: 0.9065422707091505, val acc: 0.682707667350769
epoch: 2, loss: 0.8761650123469584, acc: 0.6921985149383545,      val loss: 0.8051631742011244, val acc: 0.7189496755599976
epoch: 3, loss: 0.7819633651870379, acc: 0.7235084772109985,      val loss: 0.7636748775125692, val acc: 0.7327276468276978
epoch: 4, loss: 0.7313383280353827, acc: 0.7460412979125977,      val loss: 0.7061512276006583, val acc: 0.7598841786384583
epoch: 5, loss: 0.6902052811034124, acc: 0.7585172653198242,      val loss: 0.710792664140939, val acc: 0.7541933059692383
epoch: 6, loss: 0.6601356580824861, acc: 0.7694937586784363,      val loss: 0.6813203324905981, val acc: 0.769069492816925
epoch: 7, loss: 0.6379130195709505, acc: 0.7775911688804626,      val loss: 0.653992762390417, val acc: 0.776457667350769
epoch: 8, loss: 0.6179025139533322, acc: 0.7821897268295288,      val loss: 0.6546257712399236, val acc: 0.7769568562507629
epoch: 9, loss: 0.5969862834868032, acc: 0.7906469702720642,      val loss: 0.6211730434586065, val acc: 0.7905351519584656
epoch: 10, loss: 0.5893858630288814, acc: 0.7949455976486206,      val loss: 0.634566101212852, val acc: 0.780451238155365
epoch: 11, loss: 0.5731508998118069, acc: 0.8000239729881287,      val loss: 0.6467042784816541, val acc: 0.7775558829307556
epoch: 12, loss: 0.5638545042264942, acc: 0.8032429814338684,      val loss: 0.6207450504024951, val acc: 0.7892372012138367
epoch: 13, loss: 0.5516865827758115, acc: 0.8065218925476074,      val loss: 0.6165681539442592, val acc: 0.7894368767738342
epoch: 14, loss: 0.5429838309666322, acc: 0.8096808791160583,      val loss: 0.6213038140973344, val acc: 0.7946285605430603

損失関数の学習経過をグラフにしてみます。

plt.style.use("ggplot")
plt.plot(losses, label="train loss")
plt.plot(val_losses, label="test loss")
plt.legend()

以下のように、徐々に損失が小さくなっていることが分かります。

次は正解率をグラフにして見てみます。

plt.plot(accs, label="train acc")
plt.plot(val_accs, label="test acc")
plt.legend()

こちらも徐々に正解率が上がっていくことが分かります。

最後に

ここまでで「CNNでデータ拡張をしながら、CIFAR10を分類する」サンプルコードを実装しました、

PyTorchであれば結構簡単に実装が出来ることを感じられたと思います。

ここまでありがとうございました!

最新情報をチェックしよう!