How to preprocess a flower training dataset with TensorFlow?
Flower recognition is a classic problem in computer vision, and deep learning offers an effective way to tackle it. TensorFlow, one of the most popular deep learning frameworks, can be used for this task as well. This article shows how to preprocess a flower training dataset with TensorFlow, laying a solid data foundation for training a flower recognition model.
Dataset overview
This walkthrough uses TensorFlow's official flower dataset (flower_photos). It contains five classes: daisies, dandelions, roses, sunflowers, and tulips, with several hundred images per class. The images come in varying sizes, which is why the preprocessing code below records each image's height and width.
Code implementation
First, download the flower dataset and extract it to a suitable folder, for example ~/flower_photos/. Under ~/flower_photos/ there are five subfolders, each containing the flower images of one class. A short download-and-check sketch is shown below.
import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Directory where the TFRecord files will be written
OUTPUT_DIR = 'output_directory'
# Directory of the extracted flower dataset
FLOWER_DATA_DIR = 'flower_photos'
# Collect every image path and its label, split into training/testing/validation.
def create_image_lists(sess):
    sub_dirs = [x[0] for x in os.walk(FLOWER_DATA_DIR)]
    is_root_dir = True
    num_classes = 0
    image_lists = {}
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        if dir_name == FLOWER_DATA_DIR:
            continue
        tf.logging.info('Processing %s' % dir_name)
        for extension in extensions:
            file_glob = os.path.join(FLOWER_DATA_DIR, dir_name, '*.' + extension)
            file_list.extend(gfile.Glob(file_glob))
        if not file_list:
            continue
        label_name = dir_name.lower()
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # Randomly assign each image: ~80% training, ~10% testing, ~10% validation.
            chance = np.random.randint(100)
            if chance < 80:
                training_images.append(base_name)
            elif chance < 90:
                testing_images.append(base_name)
            else:
                validation_images.append(base_name)
        label_index = num_classes
        num_classes += 1
        image_lists[label_name] = {
            'dir': dir_name,
            'label': label_index,  # numeric class id, read back by get_label below
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return image_lists
# Resolve the full path of one image for a given label and split.
def get_image_path(image_lists, label_name, index, image_dir, category):
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    mod_index = index % len(category_list)
    base_name = category_list[mod_index]
    sub_dir = label_lists['dir']
    full_path = os.path.join(image_dir, sub_dir, base_name)
    return full_path

# Look up the numeric label of one class.
def get_label(image_lists, label_name):
    return image_lists[label_name]['label']

# Collect the numeric labels of every class in the dataset.
def get_labels(image_lists):
    labels = []
    for label_name in list(image_lists.keys()):
        labels.append(get_label(image_lists, label_name))
    return labels
def _convert_dataset(split_name, filenames, image_lists, image_dir, output_directory):
    assert split_name in ['train', 'test', 'validation']
    num_shards = 4
    with tf.Graph().as_default():
        image_reader = ImageReader()
        with tf.Session('') as sess:
            for shard_id in range(num_shards):
                output_filename = os.path.join(
                    output_directory,
                    'flowers_%s_%.2d-of-%.2d.tfrecord' % (split_name, shard_id, num_shards))
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * len(filenames) // num_shards
                    end_ndx = (shard_id + 1) * len(filenames) // num_shards
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, len(filenames), shard_id))
                        sys.stdout.flush()
                        # Read the raw JPEG bytes and record the image dimensions.
                        image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                        height, width = image_reader.read_image_dims(sess, image_data)
                        # The class name is the parent directory of the image file.
                        class_name = os.path.basename(os.path.dirname(filenames[i]))
                        class_id = get_label(image_lists, class_name.lower())
                        example = image_to_tfexample(image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())
    sys.stdout.write('\n')
    sys.stdout.flush()
class ImageReader(object):
    """Decodes JPEG data so that image height and width can be recorded."""

    def __init__(self):
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        image = self.decode_jpeg(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_jpeg(self, sess, image_data):
        image = sess.run(self._decode_jpeg,
                         feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image
# Convert every split of the dataset into sharded TFRecord files.
def create_tfrecord_file(image_lists, image_dir, output_directory):
    # Gather the file paths of all classes first and write one set of shards
    # per split; converting class by class would overwrite the shard files.
    split_map = {'train': 'training', 'test': 'testing', 'validation': 'validation'}
    for split_name, category in split_map.items():
        filenames = []
        for label_name, data in image_lists.items():
            filenames.extend(
                get_image_path(image_lists, label_name, index, image_dir, category)
                for index in range(len(data[category])))
        _convert_dataset(split_name, filenames, image_lists, image_dir, output_directory)
Here, the file_list variable holds all the file names found in each class folder, and the images are split into training, test, and validation sets in an 80%:10%:10% ratio. The flower images are then converted into TFRecord format so they can be consumed conveniently later, using the following helper functions provided by the official examples:
def image_to_tfexample(image_data, image_format, height, width, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(image_data),
        'image/format': _bytes_feature(image_format),
        'image/class/label': _int64_feature(class_id),
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
    }))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
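With the helper functions in place, a minimal driver can tie everything together. This is a sketch rather than part of the original article; it assumes the FLOWER_DATA_DIR and OUTPUT_DIR constants defined at the top of the script.

if __name__ == '__main__':
    # Make sure the output directory exists before writing TFRecord shards.
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
    with tf.Session() as sess:
        # Scan the class folders and split the images into train/test/validation.
        image_lists = create_image_lists(sess)
        # Write the sharded TFRecord files for every split.
        create_tfrecord_file(image_lists, FLOWER_DATA_DIR, OUTPUT_DIR)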
Conclusion
This article walked through how to preprocess the flower training dataset with TensorFlow. After preprocessing we obtain a TFRecord-format dataset, which provides a solid data foundation for training a flower recognition model; the complete code is listed above.
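As a rough sketch of how the written records can later be consumed for training (not part of the original article), the features stored by image_to_tfexample can be parsed back with TF 1.x's tf.parse_single_example, matching the flowers_%s_%.2d-of-%.2d.tfrecord naming used in _convert_dataset:

import glob

def parse_flower_example(serialized_example):
    # The feature spec mirrors what image_to_tfexample wrote out.
    features = tf.parse_single_example(serialized_example, features={
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/format': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
    })
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    label = features['image/class/label']
    return image, label

# Build an input pipeline over the training shards.
train_files = glob.glob(os.path.join(OUTPUT_DIR, 'flowers_train_*.tfrecord'))
dataset = tf.data.TFRecordDataset(train_files).map(parse_flower_example)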