如何为Tensorflow配置花卉数据集以提高性能?
TensorFlow 是一个非常流行的深度学习框架,它有着广泛的可定制性以及强大的性能。TensorFlow 可以处理几乎所有类型的数据,包括图像、文本、语音以及视频等等。而对于处理图像数据,花卉数据集是一个非常常见的任务,本文将介绍如何为Tensorflow配置花卉数据集以提高性能。
更多Python文章,请阅读:Python 教程
花卉数据集
花卉数据集是一个经典的图像分类数据集,它由 102 类不同种类的花朵图像构成,每个类别拥有至少 40 张图像。每张图像的大小为 320×240 像素。花卉数据集是一个很好的开始,因为它既适合新手,也适合专业人员。
本文将使用 TensorFlow 中的 tf.keras 库来处理花卉数据集。我们需要先下载数据集并将其准备好以后,才能在 TensorFlow 中使用相关库进行训练和测试。
下载
花卉数据集可以从 UCI 机器学习资源库上下载。
下载地址:http://archive.ics.uci.edu/ml/datasets/102+Flowers
或者,可以使用 TensorFlow Datasets(tfds)包下载:
import tensorflow_datasets as tfds
(train_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:]', 'test'],
with_info=True,
as_supervised=True,
)
预处理
在下载并准备好数据集后,我们应该进行预处理以便于在 TensorFlow 中使用。图像需要被规范化和重新缩放为标准大小。
IMG_SIZE = 224
def format_image(image, label):
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))/255.0
return image, label
BATCH_SIZE = 32
train_ds = train_ds.map(format_image).batch(BATCH_SIZE)
test_ds = test_ds.map(format_image).batch(BATCH_SIZE)
在传递数据集时,我们可以在 TensorFlow 中使用 map 函数将格式化的图像和标签设置为一个批次大小。
数据增强
我们可以通过在训练期间人为地向数据集添加变换或扭曲来增强数据集。
data_augmentation = Sequential([
tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
以上代码是一个数据增强器,它将水平翻转,随机旋转一些图片。
for image, label in train_ds.take(1):
plt.imshow(image[0])
plt.title(get_label_name(label[0]))
plt.show()
plt.imshow(data_augmentation(image)[0])
plt.title(get_label_name(label[0]))
plt.show()
将图像传递到增强器时,我们可以看到新的图像与原始图像不同。
建模
现在我们已经准备好对花卉数据集进行训练。我们将使用预处理后的数据集和增强器来训练模型。
base_model = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(IMG_SIZE, IMG_SIZE, 3))
base_model.trainable = False
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(metadata.features['label'].num_classes, activation='softmax')
model = tf.keras.Sequential([
data_augmentation,
base_model,
global_average_layer,
prediction_layer
])
我们在这里使用 MobileNet V2 模型进行预测。MobileNet V2 是一种轻量级卷积神经网络,可用于图像分类和目标检测等任务。可以通过设置 train_model 为 False 使它不参与训练,同时在模型顶部添加全局平均池化层和密集层。这种方法可以帮助我们在不牺牲精度的情况下使训练时间更短。我们还将密集层的激活函数设置为 softmax,这是因为花卉数据集是一个多类别分类问题。
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.summary()
训练和评估
我们可以在训练和测试集上进行模型训练和评估。训练过程可以通过调用 fit 函数来完成。
initial_epochs = 10
loss0, accuracy0 = model.evaluate(test_ds)
history = model.fit(train_ds,
epochs=initial_epochs,
validation_data=test_ds)
loss, accuracy = model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: ", accuracy)
我们可以通过运行以下代码来可视化训练过程。
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(initial_epochs)
plt.figure(figsize=(12, 12))
plt.subplot(2, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(2, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
结论
在这篇文章中,我们已经介绍了如何为 TensorFlow 配置花卉数据集以提高性能。我们学习了如何下载、预处理、数据增强、建模、训练和评估。希望这些技巧对你有所帮助并且可以启发你在 TensorFlow 中进行更多有趣的项目。