【機械学習②】Kerasでデータ拡張をした時に画像がどのように変化するのか確認してみた

スポンサーリンク
使ってみた
スポンサーリンク
スポンサーリンク

概要

 Pythonの深層学習ライブラリKerasにおいてデータ拡張(Data Augmentation)を適用する際、画像がどのような変形を受けているのか、実際に適用後の画像を見て確認してみたいと思います。元画像にはCIFAR-10を用います。

そもそもデータ拡張(Data Augmentation)とは?

 深層学習(Deep Learning)のモデルを訓練する時には、大量のデータが必要になります。しかし、その際の画像データが足りない時には、画像を回転したり、縮小したり、もしくはRGB値を変更したりするなどの加工を加えることで、画像数を水増しします。これをデータ拡張と呼びます。

Kerasでデータ拡張

 Kerasでは、データ拡張は関数ImageDataGenerator.flowで行われます。この関数の返り値はPythonのイテレータなので、for文のinの後ろに置いてループを回せます。ただ、際限なく値を返すので、適切な回数でbreakが必要です。

以下のコードはGoogle Colaboratory上での実行を想定しています。

ライブラリをインポートしてMatplotlibを日本語化

 まずは、Keras、Numpy、Matplotlib等のライブラリをインポートして、さらにMatplotlibを日本語化します。

# Kerasその他ライブラリをインポート
import keras
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator

# Matplotlibの日本語化
!pip install japanize-matplotlib
import japanize_matplotlib

以下は、Google Colaboratoryで実行した様子です。

弱めのaugmentationと強めのaugmentationを実施

 次に、CIFAR-10をロードして日本語でラベル付けした後、指定した番号の画像に対して弱めのaugmentationと強めのaugmentationを実施します。

augmentationの度合いは以下の項目で指定できます。これらの値を変更すると、画像の加工具合が変わってきます。

rotation_range指定した範囲の角度でランダムに回転
horizontal_flip水平方向にランダムに反転
vertical_flip垂直方向にランダムに反転
height_shift_range指定した範囲の割合でランダムに垂直移動
width_shift_range指定した範囲の割合でランダムに水平移動
zoom_range指定した範囲の割合でランダムにズーム
shear_range指定した範囲の割合でランダムにシアー(剪断)をかける
channel_shift_range指定した範囲でランダムにチャネルの値(色)を変更
# CIFAR-10のロード
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# 日本語でラベル付け
cifar10_labels = np.array(['飛行機','クルマ','トリ','ネコ','シカ','イヌ','カエル','ウマ','船','トラック'])

# 画像番号[0〜49999で変えてみてください]
n = 40

# 弱めのaugmentation
augmented_data = ImageDataGenerator(rotation_range=20, horizontal_flip=0.2, vertical_flip=0.2, height_shift_range=0.2,
                                width_shift_range=0.2, zoom_range=0.2, channel_shift_range=20
                                ).flow(x_train[n:n+1], y_train[n:n+1], batch_size=1)

# 強めのaugmentation
strongly_augmented_data = ImageDataGenerator(rotation_range=50, horizontal_flip=0.3, vertical_flip=0.3, height_shift_range=0.5,
                                width_shift_range=0.5, zoom_range=0.5, shear_range=0.5, channel_shift_range=100
                                ).flow(x_train[n:n+1], y_train[n:n+1], batch_size=1)

# 表示行数
rows = 2
# 表示列数
cols = 7

# 元画像の表示
plt.figure(figsize=(16,10))
plt.subplot(rows*2, cols, 1)
plt.imshow(x_train[n])
plt.title('元画像 ' + cifar10_labels[y_train[n][0]])

pos = 2 # プロットの位置インデックス

# 弱めのaugmentationの可視化
for data in augmented_data:
    plt.subplot(rows*2, cols, pos)
    # augment後の各ピクセルの値はfloat32になってしまっているので, uint8(符号無し8ビット整数値)にキャストしてRGB値0〜255に切り捨て
    plt.imshow(data[0][0].astype('uint8'))
    plt.axis('off')
    plt.title('弱拡張 ' + cifar10_labels[data[1][0][0]])
    pos += 1
    if pos >= rows * cols +1:
        break

# 強めのaugmentationの可視化
for data in strongly_augmented_data:
    plt.subplot(rows*2, cols, pos)
    plt.imshow(data[0][0].astype('uint8'))
    plt.axis('off')
    plt.title('強拡張 '+ cifar10_labels[data[1][0]][0])
    pos += 1
    if pos >= rows * cols *2 + 1:
        break

plt.show()

以下は、Google Colaboratoryで実行した様子です。

このコードを実行すると、CIFAR-10の40番目に登録されている犬の画像が、2段階の強さでデータ拡張された状態で表示されます。

43番目の馬だとこんな感じになります。

コメント

プロフィール

bond.fatherは、主にAI、Python、バイオインフォマティクスに関わっています。
bond.motherは、当ブログの管理者です。インフラから開発まで広く勉強中です。
当記事はbond.fatherが作成し、bond.motherが編集をしています。

bond.father & motherをフォローする
タイトルとURLをコピーしました