转载自http://blog.csdn.net/u012759136/article/details/52232266 原文作者github地址 概述关于Tensorflow读取数据,官网给出了三种方法:
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分 TFRecordsTFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(等会儿就知道为什么了)… …总而言之,这样的文件格式好处多多,所以让我们用起来吧。 TFRecords文件包含了 从TFRecords文件中读取数据, 可以使用 接下来,让我们开始读取数据之旅吧~ 生成TFRecords文件我们使用 import osimport tensorflow as tf from PIL import Imagecwd = os.getcwd()'''此处我加载的数据目录如下:0 -- img1.jpg img2.jpg img3.jpg ...1 -- img1.jpg img2.jpg ...2 -- ......'''writer = tf.python_io.TFRecordWriter('train.tfrecords')for index, name in enumerate(classes): class_path = cwd + name + '/' for img_name in os.listdir(class_path): img_path = class_path + img_name img = Image.open(img_path) img = img.resize((224, 224)) img_raw = img.tobytes() #将图片转化为原生bytes example = tf.train.Example(features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) #序列化为字符串writer.close()
关于 基本的,一个 就这样,我们把相关的信息都存到了一个文件中,所以前面才说不用单独的label文件。而且读取也很方便。 for serialized_example in tf.python_io.tf_record_iterator('train.tfrecords'): example = tf.train.Example() example.ParseFromString(serialized_example) image = example.features.feature['image'].bytes_list.value label = example.features.feature['label'].int64_list.value # 可以做一些预处理之类的 print image, label 使用队列读取一旦生成了TFRecords文件,接下来就可以使用队列( def read_and_decode(filename): #根据文件名生成一个队列 filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.int32) return img, label 之后我们可以在训练的时候这样使用 img, label = read_and_decode('train.tfrecords')#使用shuffle_batch可以随机打乱输入img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=2000, min_after_dequeue=1000)init = tf.initialize_all_variables()with tf.Session() as sess: sess.run(init) threads = tf.train.start_queue_runners(sess=sess) for i in range(3): val, l= sess.run([img_batch, label_batch]) #我们也可以根据需要对val, l进行处理 #l = to_categorical(l, 12) print(val.shape, l) 至此,tensorflow高效从文件读取数据差不多完结了。 恩?等等…什么叫差不多?对了,还有几个注意事项: 第一,tensorflow里的graph能够记住状态( 第二,tensorflow中的队列和普通的队列差不多,不过它里面的 第三, 总结
|
|
来自: MaysThree > 《tensorflow》