アクセスありがとうございます!次は「とあるエンジニアのエソラゴト」で検索して頂けると嬉しいです!

【PyTorch】CNNでデータ拡張をしながら、CIFAR10を分類するサンプルコード【初心者】

どうも、エンジニアのエソラ(@ya6madev)です。

普段はSIer企業でDXとかAI開発をしながら、自社サービスの開発をしています。

はじめに

冬休みということでDeep Learningを勉強しています。

今回は以下に挑戦します!

  • CNNを使って、CIFAR10の画像データを分類
  • 精度を上げるために、データ拡張を実施して訓練データ数をかさ増し

では、やっていきます!

前提

PyTorchを使います。

開発環境は「Google Colaboratory」を使っていきます。

かなり重い処理になりますので、GPUを利用することをオススメします!

AWSのSageMakerを使ってGPUを利用する場合は、以下をご参照下さい。

関連記事

[afTag id=7893] 私 冬休みということで、Deep Learningを集中的に勉強しています。 普段は「Google Colaboratory」を使っていますが、無料使用だとGPU使用に制限があるため、すぐにG[…]

テストデータはCIFAR10という画像データを使います。

CIFAR-10とは、以下のdog(犬)やship(船)といった画像データを合計で10個集めて、正解ラベルをつけたデータセットです。
  1. plane(飛行機)
  2. car(車)
  3. bird(鳥)
  4. cat(猫)
  5. deer(鹿)
  6. dog(犬)
  7. frog(カエル)
  8. horse(馬)
  9. ship(船)
  10. truck(トラック)
MNISTという手書き文字データよりも複雑なものが画像データとして写っているので、学習難易度は上がる傾向にあります。

コードの全体はGitHubにアップロードしています。

サンプルコードの全体はこちら。

CIFAR10のデータ準備の実装サンプルコード

ライブラリのインポート、GPU使用

まずは必要なライブラリを、インポートします。

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

そしてGPUが使えたら使いたいので、「GPUが使えるのであればGPUを使うように、そうでなければCPUを使うように」という設定をします。

device = "cuda" if torch.cuda.is_available() else "cpu"

データの変換

そして、以下のコードでデータを変換する処理をコーディングします。

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))                                
])

transforms.ToTensor()はTensor型に変換をしています。

transforms.Normalize((0.5,), (0.5,))は第1引数が平均、第2引数が標準偏差を表し、それに合わせて正規化するように変換をします。

データ拡張

データ拡張とは、「手元で使える画像データが少ない場合に、プログラムからデータを傾けたり、色彩を変化させてデータを擬似的に増やす手法」です。

Deep Learningの分野ではデータ量は多ければ多いほど、良い精度が出る傾向にあるので、データ量を増やすことは非常に大切です。

ただ、集めることが大変な時もあると思うので、是非覚えておきたい手法です。

やり方は簡単で、上記に実装したtransforms.Composeにコードを追加するだけです。

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    # チャネル毎に画像全体の平均値と標準偏差が0.5になるように標準化
    transforms.Normalize((0.5,), (0.5,))                                
])

transforms.RandomHorizontalFlip()は、「ランダムに画像データを水平反転」をします。

transforms.ColorJitter()は、「ランダムに画像の明るさ、コントラスト、彩度、色相を変える」ことが出来ます。

transforms.RandomRotation()は、「ランダムに回転をさせる」ことが出来ます。

CIFAR10のダウンロード

そして、CIFAR10のデータを用意します。

CIFAR10のサンプルデータはPyTorchのライブラリを使って、インターネット経由でダウンロードしてきます。

今回は実際に訓練を行う訓練データと、検証用に使用するテストデータの2つに分けます。

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)

そして、以下のようにデータローダーにそれぞれのデータを代入します。

# 訓練用データ
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# テスト用データ
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

データの中身を見てみます。

data_iter = iter(train_dataloader)
imgs, labels = data_iter.next()

ラベルを出すると、以下のように数値の正解ラベルが入っています。

labels

# 以下は出力値
tensor([5, 9, 8, 4, 3, 4, 4, 7, 2, 6, 8, 0, 6, 0, 7, 2, 3, 6, 4, 8, 0, 7, 8, 6,
        5, 3, 8, 5, 9, 0, 4, 1])

では、写真を出力してみます。

img = imgs[0]
img_permute = img.permute(1, 2, 0)
img_permute = 0.5 * img_permute + 0.5
img_permute = np.clip(img_permute, 0, 1)
plt.imshow(img_permute)

鹿の画像が出ました!

CNNモデルの実装サンプルコード

では、ここからCNNのモデル実装をしていきます。

モデルの全体像は以下のようになります。

class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(in_features=4 * 4 * 128, out_features=num_classes)
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

nn.Conv2dで畳み込みフィルターを行います。

活性化関数にはReLUを使用します。

nn.MaxPool2dでマックスプーリングを行います。

model = CNN(10)
model.to(device)

損失関数はクロスエントロピーを選択し、最適化関数はAdamを使用します。

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

CNNモデルの訓練処理の実装サンプルコード

それでは、学習をさせます。

最初に訓練データでモデルを学習させて、テストデータで検証をします。

num_epochs = 15
losses = []
accs = []
val_losses = []
val_accs = []
for epoch in range(num_epochs):
    running_loss = 0.0
    running_acc = 0.0
    for imgs, labels in train_dataloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        running_loss += loss.item()
        pred = torch.argmax(output, dim=1)
        running_acc += torch.mean(pred.eq(labels).float())
        optimizer.step()
    running_loss /= len(train_dataloader)
    running_acc /= len(train_dataloader)
    losses.append(running_loss)
    accs.append(running_acc)
    #
    # test loop
    #
    val_running_loss = 0.0
    val_running_acc = 0.0
    for val_imgs, val_labels in test_dataloader:
        val_imgs = val_imgs.to(device)
        val_labels = val_labels.to(device)
        val_output = model(val_imgs)
        val_loss = criterion(val_output, val_labels)
        val_running_loss += val_loss.item()
        val_pred = torch.argmax(val_output, dim=1)
        val_running_acc += torch.mean(val_pred.eq(val_labels).float())
    val_running_loss /= len(test_dataloader)
    val_running_acc /= len(test_dataloader)
    val_losses.append(val_running_loss)
    val_accs.append(val_running_acc)
    print("epoch: {}, loss: {}, acc: {}, \
     val loss: {}, val acc: {}".format(epoch, running_loss, running_acc, val_running_loss, val_running_acc))

すると、以下のような結果になります。

epoch: 0, loss: 1.433305179813468, acc: 0.47696736454963684,      val loss: 1.0972439880950002, val acc: 0.6093250513076782
epoch: 1, loss: 1.0305320225460592, acc: 0.6342170238494873,      val loss: 0.9065422707091505, val acc: 0.682707667350769
epoch: 2, loss: 0.8761650123469584, acc: 0.6921985149383545,      val loss: 0.8051631742011244, val acc: 0.7189496755599976
epoch: 3, loss: 0.7819633651870379, acc: 0.7235084772109985,      val loss: 0.7636748775125692, val acc: 0.7327276468276978
epoch: 4, loss: 0.7313383280353827, acc: 0.7460412979125977,      val loss: 0.7061512276006583, val acc: 0.7598841786384583
epoch: 5, loss: 0.6902052811034124, acc: 0.7585172653198242,      val loss: 0.710792664140939, val acc: 0.7541933059692383
epoch: 6, loss: 0.6601356580824861, acc: 0.7694937586784363,      val loss: 0.6813203324905981, val acc: 0.769069492816925
epoch: 7, loss: 0.6379130195709505, acc: 0.7775911688804626,      val loss: 0.653992762390417, val acc: 0.776457667350769
epoch: 8, loss: 0.6179025139533322, acc: 0.7821897268295288,      val loss: 0.6546257712399236, val acc: 0.7769568562507629
epoch: 9, loss: 0.5969862834868032, acc: 0.7906469702720642,      val loss: 0.6211730434586065, val acc: 0.7905351519584656
epoch: 10, loss: 0.5893858630288814, acc: 0.7949455976486206,      val loss: 0.634566101212852, val acc: 0.780451238155365
epoch: 11, loss: 0.5731508998118069, acc: 0.8000239729881287,      val loss: 0.6467042784816541, val acc: 0.7775558829307556
epoch: 12, loss: 0.5638545042264942, acc: 0.8032429814338684,      val loss: 0.6207450504024951, val acc: 0.7892372012138367
epoch: 13, loss: 0.5516865827758115, acc: 0.8065218925476074,      val loss: 0.6165681539442592, val acc: 0.7894368767738342
epoch: 14, loss: 0.5429838309666322, acc: 0.8096808791160583,      val loss: 0.6213038140973344, val acc: 0.7946285605430603

損失関数の学習経過をグラフにしてみます。

plt.style.use("ggplot")
plt.plot(losses, label="train loss")
plt.plot(val_losses, label="test loss")
plt.legend()

以下のように、徐々に損失が小さくなっていることが分かります。

次は正解率をグラフにして見てみます。

plt.plot(accs, label="train acc")
plt.plot(val_accs, label="test acc")
plt.legend()

こちらも徐々に正解率が上がっていくことが分かります。

最後に

ここまでで「CNNでデータ拡張をしながら、CIFAR10を分類する」サンプルコードを実装しました、

PyTorchであれば結構簡単に実装が出来ることを感じられたと思います。

ここまでお読み頂き、ありがとうございました。

もし、「面白かった」、「参考になった」という方がいましたら、以下のソーシャルボタンからシェア頂けると泣いて喜びます!!

またブログランキングにも参加しています。

よろしければポチッとお願いします!

それでは、良いエンジニアライフをお過ごし下さい!

Aidemy Premium Plan

あなたは、AIエンジニアになりたいと思ったけど、以下のように感じたことはありませんか?

  • どうやって勉強したら良いのか分からない!
  • 何を勉強したら良いか分からない!
  • 一人で勉強を続けられる気がしない!

このように感じたら、AI(人工知能)特化型プログラミングスクールである「Aidemy Premium Plan」がおすすめです。


Aidemy Premium Planは「マンツーマン指導によって、3ヶ月で即使えるAIスキルを」というコンセプトで初心者からでも無理なく、

AIエンジニアになるために必要な学習が用意されています。

AIに仕事が奪われる時代がもうすぐ来ると言われていますが、今のうちにAIエンジニアになっておけば、「AIを作る側に回ること」が出来ます。

興味がある方は無料のビデオカウンセリングに申し込んで見ることをおすすめします。

またAidemy Premium Planは教育訓練給付金認定講座のため、受講する時は必ずキャッシュバックを受けるようにお気をつけ下さい!

\無料ビデオカウンセリングに申し込む/

 

最新情報をチェックしよう!