1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size, shuffle):
    """Build a queued batch of images and labels from a single example.

    Args:
        image: Tensor for one image; enqueued together with ``label``.
        label: Tensor for the label of ``image``.
        min_queue_examples: Number of examples used both as the
            ``min_after_dequeue`` threshold (shuffled case) and as the base
            of the queue capacity.
        batch_size: Number of examples dequeued per batch.
        shuffle: If true, batch through ``tf.train.shuffle_batch``;
            otherwise through ``tf.train.batch``.

    Returns:
        A pair ``(images, labels)`` where ``labels`` is the label batch
        reshaped to ``[batch_size]``.
    """
    # Arguments common to both the shuffling and the in-order queue.
    queue_args = dict(
        batch_size=batch_size,
        num_threads=16,
        capacity=min_queue_examples + 3 * batch_size,
    )

    example = [image, label]
    if shuffle:
        images, label_batch = tf.train.shuffle_batch(
            example,
            min_after_dequeue=min_queue_examples,
            **queue_args)
    else:
        images, label_batch = tf.train.batch(example, **queue_args)

    # Display the training images in the visualizer.
    tf.image_summary('images', images)

    return images, tf.reshape(label_batch, [batch_size])
def distorted_inputs(data_dir, batch_size):
    """Construct distorted (augmented) input for CIFAR training using Reader ops.

    Args:
        data_dir: Path to the CIFAR-10 data directory.
        batch_size: Number of images per batch.

    Returns:
        images: Images. 4D tensor of size [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: Labels. 1D tensor of size [batch_size].

    Raises:
        ValueError: If any of the five ``data_batch_N.bin`` files is missing.
    """
    batch_files = []
    for index in xrange(1, 6):
        batch_files.append(os.path.join(data_dir, 'data_batch_%d.bin' % index))
    for path in batch_files:
        if not tf.gfile.Exists(path):
            raise ValueError('Failed to find file: ' + path)

    # FIFO queue of file names that the record reader consumes.
    name_queue = tf.train.string_input_producer(batch_files)

    # Read one example from the files in the file-name queue.
    record = read_cifar10(name_queue)
    float_pixels = tf.cast(record.uint8image, tf.float32)

    # Image processing for training: apply several random distortions.
    # Randomly crop a [IMAGE_SIZE, IMAGE_SIZE] section of the image.
    crop_shape = [IMAGE_SIZE, IMAGE_SIZE, 3]
    augmented = tf.random_crop(float_pixels, crop_shape)

    # Randomly flip the image horizontally.
    augmented = tf.image.random_flip_left_right(augmented)

    # These operations are not commutative; consider randomizing the order
    # in which they are applied.
    augmented = tf.image.random_brightness(augmented, max_delta=63)
    augmented = tf.image.random_contrast(augmented, lower=0.2, upper=1.8)

    # Subtract off the mean and divide by the variance of the pixels.
    whitened = tf.image.per_image_whitening(augmented)

    # Keep 40% of an epoch in the queue so the random shuffling mixes well.
    queue_floor = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4)
    print('Filling queue with %d CIFAR images before starting to train. '
          'This will take a few minutes.' % queue_floor)

    # Generate a batch of images and labels by building up a queue of examples.
    return _generate_image_and_label_batch(
        whitened, record.label, queue_floor, batch_size, shuffle=True)
def inputs(eval_data, data_dir, batch_size):
    """Construct input for CIFAR evaluation using the Reader ops.

    Args:
        eval_data: Boolean selecting the data set. When false, the five
            ``data_batch_N.bin`` training files are read; when true,
            ``test_batch.bin`` is read.
        data_dir: Path to the CIFAR-10 data directory.
        batch_size: Number of images per batch.

    Returns:
        images: Images. 4D tensor of size [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: Labels. 1D tensor of size [batch_size].

    Raises:
        ValueError: If any of the selected data files is missing.
    """
    if not eval_data:
        filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                     for i in xrange(1, 6)]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    else:
        filenames = [os.path.join(data_dir, 'test_batch.bin')]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # FIFO queue of file names that the record reader consumes.
    filename_queue = tf.train.string_input_producer(filenames)

    # Read one example from the files in the file-name queue.
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # Image processing for evaluation.
    # Crop the central [height, width] of the image. The API takes
    # (image, target_height, target_width); the original call passed
    # (width, height), which only worked because both equal IMAGE_SIZE.
    resized_image = tf.image.resize_image_with_crop_or_pad(
        reshaped_image, height, width)

    # Subtract off the mean and divide by the variance of the pixels.
    float_image = tf.image.per_image_whitening(resized_image)

    # Size the queue at 40% of an epoch, the same fraction the training
    # pipeline uses.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)

    # Generate a batch of images and labels by building up a queue of examples.
    return _generate_image_and_label_batch(
        float_image, read_input.label, min_queue_examples, batch_size,
        shuffle=False)
|