概要
Pythonの深層学習ライブラリKerasにおいてデータ拡張(Data Augmentation)を適用する際、画像がどのような変形を受けているのか、実際に適用後の画像を見て確認してみたいと思います。元画像にはCIFAR-10を用います。
そもそもデータ拡張(Data Augmentation)とは?
深層学習(Deep Learning)のモデルを訓練する時には、大量のデータが必要になります。しかし、その際の画像データが足りない時には、画像を回転したり、縮小したり、もしくはRGB値を変更したりするなどの加工を加えることで、画像数を水増しします。これをデータ拡張と呼びます。
Kerasでデータ拡張
Kerasでは、データ拡張は関数ImageDataGenerator.flowで行われます。この関数の返り値はPythonのイテレータなので、for文のinの後ろに置いてループを回せます。ただ、際限なく値を返すので、適切な回数でbreakが必要です。
以下のコードはGoogle Colaboratory上での実行を想定しています。
ライブラリをインポートしてMatplotlibを日本語化
まずは、Keras、Numpy、Matplotlib等のライブラリをインポートして、さらにMatplotlibを日本語化します。
# Kerasその他ライブラリをインポート
import keras
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator
# Matplotlibの日本語化
!pip install japanize-matplotlib
import japanize_matplotlib
以下は、Google Colaboratoryで実行した様子です。
弱めのaugmentationと強めのaugmentationを実施
次に、CIFAR-10をロードして日本語でラベル付けした後、指定した番号の画像に対して弱めのaugmentationと強めのaugmentationを実施します。
augmentationの度合いは以下の項目で指定できます。これらの値を変更すると、画像の加工具合が変わってきます。
rotation_range | 指定した範囲の角度でランダムに回転 |
horizontal_flip | 水平方向にランダムに反転 |
vertical_flip | 垂直方向にランダムに反転 |
height_shift_range | 指定した範囲の割合でランダムに垂直移動 |
width_shift_range | 指定した範囲の割合でランダムに水平移動 |
zoom_range | 指定した範囲の割合でランダムにズーム |
shear_range | 指定した範囲の割合でランダムにシアー(剪断)をかける |
channel_shift_range | 指定した範囲でランダムにチャネルの値(色)を変更 |
# CIFAR-10のロード
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# 日本語でラベル付け
cifar10_labels = np.array(['飛行機','クルマ','トリ','ネコ','シカ','イヌ','カエル','ウマ','船','トラック'])
# 画像番号[0〜49999で変えてみてください]
n = 40
# 弱めのaugmentation
augmented_data = ImageDataGenerator(rotation_range=20, horizontal_flip=0.2, vertical_flip=0.2, height_shift_range=0.2,
width_shift_range=0.2, zoom_range=0.2, channel_shift_range=20
).flow(x_train[n:n+1], y_train[n:n+1], batch_size=1)
# 強めのaugmentation
strongly_augmented_data = ImageDataGenerator(rotation_range=50, horizontal_flip=0.3, vertical_flip=0.3, height_shift_range=0.5,
width_shift_range=0.5, zoom_range=0.5, shear_range=0.5, channel_shift_range=100
).flow(x_train[n:n+1], y_train[n:n+1], batch_size=1)
# 表示行数
rows = 2
# 表示列数
cols = 7
# 元画像の表示
plt.figure(figsize=(16,10))
plt.subplot(rows*2, cols, 1)
plt.imshow(x_train[n])
plt.title('元画像 ' + cifar10_labels[y_train[n][0]])
pos = 2 # プロットの位置インデックス
# 弱めのaugmentationの可視化
for data in augmented_data:
plt.subplot(rows*2, cols, pos)
# augment後の各ピクセルの値はfloat32になってしまっているので, uint8(符号無し8ビット整数値)にキャストしてRGB値0〜255に切り捨て
plt.imshow(data[0][0].astype('uint8'))
plt.axis('off')
plt.title('弱拡張 ' + cifar10_labels[data[1][0][0]])
pos += 1
if pos >= rows * cols +1:
break
# 強めのaugmentationの可視化
for data in strongly_augmented_data:
plt.subplot(rows*2, cols, pos)
plt.imshow(data[0][0].astype('uint8'))
plt.axis('off')
plt.title('強拡張 '+ cifar10_labels[data[1][0]][0])
pos += 1
if pos >= rows * cols *2 + 1:
break
plt.show()
以下は、Google Colaboratoryで実行した様子です。
このコードを実行すると、CIFAR-10の40番目に登録されている犬の画像が、2段階の強さでデータ拡張された状態で表示されます。
43番目の馬だとこんな感じになります。
コメント