How to preprocess a flower training dataset with TensorFlow?

Flower recognition is a classic problem in computer vision, and deep learning offers an effective way to tackle it. TensorFlow, one of the most widely used deep learning frameworks, is well suited to the task. This article shows how to preprocess a flower training dataset with TensorFlow, laying a solid data foundation for training a flower recognition model.


Dataset overview

This walkthrough uses TensorFlow's official flower dataset (flower_photos). It contains five flower classes: daisy, dandelion, roses, sunflowers and tulips, with several hundred JPEG photos per class (about 3,670 images in total). The photos come in varying sizes, so a fixed input resolution such as 256×256 pixels is usually chosen later, when the model is trained.

Code implementation

First, we need to download the flower dataset and extract it to a suitable directory, for example ~/flower_photos/. Under ~/flower_photos/ there are five subdirectories, one per flower class, each containing that class's images. We then preprocess the training data and convert it into the TFRecord files that model training will consume; the full preprocessing code follows the download sketch below.
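
If you do not already have the data locally, a minimal download-and-extract sketch might look like this (it assumes the official archive URL and uses tf.keras.utils.get_file, which caches under ~/.keras/datasets; adjust the paths to your own setup):

import os
import tensorflow as tf

# Download and extract the official flower_photos archive (assumed URL).
archive = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'http://download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
data_dir = os.path.join(os.path.dirname(archive), 'flower_photos')

# Print the number of images in each class directory.
for class_name in sorted(os.listdir(data_dir)):
    class_dir = os.path.join(data_dir, class_name)
    if os.path.isdir(class_dir):
        print(class_name, len(os.listdir(class_dir)))

With the data in place, the full preprocessing script is as follows: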

import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Directory where the generated TFRecord files will be written
OUTPUT_DIR = 'output_directory'
# Root directory of the flower dataset
FLOWER_DATA_DIR = 'flower_photos'

# Collect the image file names of every class and split them at random
# into training / testing / validation sets (80% / 10% / 10%).
def create_image_lists():
    sub_dirs = [x[0] for x in os.walk(FLOWER_DATA_DIR)]
    is_root_dir = True
    num_classes = 0
    image_lists = {}
    for sub_dir in sub_dirs:
        # The first entry returned by os.walk is the root directory itself.
        if is_root_dir:
            is_root_dir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        if dir_name == FLOWER_DATA_DIR:
            continue
        tf.logging.info('Processing %s' % dir_name)
        for extension in extensions:
            file_glob = os.path.join(FLOWER_DATA_DIR, dir_name, '*.' + extension)
            file_list.extend(gfile.Glob(file_glob))
        if not file_list:
            continue
        label_name = dir_name.lower()
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # Assign each image to a split at random: 80% train, 10% test, 10% validation.
            chance = np.random.randint(100)
            if chance < 80:
                training_images.append(base_name)
            elif chance < 90:
                testing_images.append(base_name)
            else:
                validation_images.append(base_name)
        label_index = num_classes
        num_classes += 1
        image_lists[label_name] = {
            'dir': dir_name,
            'label': label_index,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return image_lists


# Build the full path of one image, given its class, split ('category') and index
def get_image_path(image_lists, label_name, index, image_dir, category):
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    mod_index = index % len(category_list)
    base_name = category_list[mod_index]
    sub_dir = label_lists['dir']
    full_path = os.path.join(image_dir, sub_dir, base_name)
    return full_path


# Return the numeric label (class id) for a given class name
def get_label(image_lists, label_name):
    return image_lists[label_name]['label']


# Return the numeric labels of all classes in the dataset
def get_labels(image_lists):
    labels = []
    for label_name in list(image_lists.keys()):
        labels.append(get_label(image_lists, label_name))
    return labels


# Write one split ('train' / 'test' / 'validation') into num_shards TFRecord files.
def _convert_dataset(split_name, filenames, image_lists, image_dir, output_directory):
    assert split_name in ['train', 'test', 'validation']
    num_shards = 4
    with tf.Graph().as_default():
        image_reader = ImageReader()
        with tf.Session('') as sess:
            for shard_id in range(num_shards):
                output_filename = os.path.join(
                    output_directory,
                    'flowers_%s_%.2d-of-%.2d.tfrecord' % (split_name, shard_id, num_shards))
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * len(filenames) // num_shards
                    end_ndx = (shard_id + 1) * len(filenames) // num_shards
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i + 1, len(filenames), shard_id))
                        sys.stdout.flush()
                        # Read the raw JPEG bytes and decode them once to get the image size.
                        image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                        height, width = image_reader.read_image_dims(sess, image_data)
                        # The parent directory name is the class name; map it to its numeric label.
                        class_name = os.path.basename(os.path.dirname(filenames[i]))
                        class_id = get_label(image_lists, class_name.lower())
                        example = image_to_tfexample(image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())
    sys.stdout.write('\n')
    sys.stdout.flush()


class ImageReader(object):
    """Decodes JPEG image data in a small TensorFlow graph so image dimensions can be read."""

    def __init__(self):
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        image = self.decode_jpeg(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_jpeg(self, sess, image_data):
        image = sess.run(self._decode_jpeg,
                         feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


# Convert every split of the dataset into TFRecord files.
def create_tfrecord_file(image_lists, image_dir, output_directory):
    # Gather the file paths of all classes for each split, then write each split once,
    # so that one class's shards do not overwrite another's.
    for split_name, category in [('train', 'training'), ('test', 'testing'), ('validation', 'validation')]:
        filenames = []
        for label_name, data in image_lists.items():
            filenames.extend(
                get_image_path(image_lists, label_name, index, image_dir, category)
                for index in range(len(data[category])))
        _convert_dataset(split_name, filenames, image_lists, image_dir, output_directory)

Here the file_list variable holds all of the file names found in each class directory, and each image is assigned at random to the training, testing or validation set in an 80%:10%:10% ratio. The flower images are then written out in TFRecord format so they can be consumed conveniently later on; for this we use the following helper functions from the official examples:

def image_to_tfexample(image_data, image_format, height, width, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': _bytes_feature(image_data),
          'image/format': _bytes_feature(image_format),
          'image/class/label': _int64_feature(class_id),
          'image/height': _int64_feature(height),
          'image/width': _int64_feature(width),
      }))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
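
To check that the conversion worked, the records can be parsed back with a feature description that matches the fields written by image_to_tfexample. The sketch below (TF 1.x style, assuming one of the shard file names produced above) reads the first example of one shard:

import tensorflow as tf

# Feature description matching the fields written by image_to_tfexample.
features = {
    'image/encoded': tf.FixedLenFeature([], tf.string),
    'image/class/label': tf.FixedLenFeature([], tf.int64),
    'image/height': tf.FixedLenFeature([], tf.int64),
    'image/width': tf.FixedLenFeature([], tf.int64),
}

# Hypothetical path of one of the shards written above.
record_path = 'output_directory/flowers_train_00-of-04.tfrecord'

with tf.Session() as sess:
    for record in tf.python_io.tf_record_iterator(record_path):
        parsed = sess.run(tf.parse_single_example(record, features))
        image = sess.run(tf.image.decode_jpeg(parsed['image/encoded'], channels=3))
        print('label:', parsed['image/class/label'], 'image shape:', image.shape)
        break  # just inspect the first example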

Conclusion

This article walked through how to preprocess a flower training dataset with TensorFlow. After preprocessing, we obtain a TFRecord-format dataset that provides a solid data foundation for training a flower recognition model.
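
As a rough sketch of how the functions above fit together (assuming the FLOWER_DATA_DIR and OUTPUT_DIR paths defined at the top of the script), a minimal driver might look like this:

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    # Build the per-class training/testing/validation file lists.
    image_lists = create_image_lists()
    # Make sure the output directory exists before writing TFRecord shards.
    if not tf.gfile.Exists(OUTPUT_DIR):
        tf.gfile.MakeDirs(OUTPUT_DIR)
    # Convert every split into sharded TFRecord files.
    create_tfrecord_file(image_lists, FLOWER_DATA_DIR, OUTPUT_DIR)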
