我们首先说一下文件读取的流程(分别讨论文本文件、图片文件以及二进制文件):
1、构造文件名队列
file_queue= tf.train.string_input_producer(string_tensor,shuffle = True)
2、读取与解码
文本:
    读取:tf.TextLineReader()
    解码:tf.decode_csv()
图片:
    读取:tf.WholeFileReader()
    解码:tf.image.decode_jpeg(contents)
         tf.image.decode_png(contents)
二进制文件:
    读取:tf.FixedLengthRecordReader(record_bytes)
    解码:tf.decode_raw()
3、批处理队列
tf.train.batch(tensors, batch_size, num_threads = 1, capacity = 32, name=None)
创建线程协调器:coord = tf.train.Coordinator()
开启队列线程(需在会话中执行):threads = tf.train.start_queue_runners(sess=None, coord=None)
回收线程:coord.request_stop()、coord.join(threads)
示例一:狗的图片读取
我这里使用的是100张狗的图片,图片类型是jpg。因为每张图片的尺寸都不相同,而批处理要求所有样本形状一致,因此加入了resize(统一缩放到200×200)的步骤。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
# 图片数据
# 图片 - 数值(三维数组shape(图片长度、图片宽度、图片通道数))
# 图片三要素
# 灰度图 [长,宽,1] 每一个像素点[0,255]
# 彩色图 [长,宽,3] 每一个像素点[0,255]
def picture_read(file_list):
    """Read JPEG files through a TF1 queue pipeline and print one batch.

    Builds a filename queue from ``file_list``, reads each file whole,
    decodes it as a 3-channel JPEG, resizes it to 200x200, and groups the
    results into batches of 100. One batch is evaluated in a session and
    the intermediate values are printed.

    Args:
        file_list: list of paths to JPEG image files.

    Returns:
        None. All results are printed.
    """
    # 1. Build the filename queue.
    file_queue = tf.train.string_input_producer(file_list)

    # 2. Read and decode.
    reader = tf.WholeFileReader()
    # key is the file name, value is the raw encoded bytes of one image.
    key, value = reader.read(file_queue)
    print(key, value)

    # Force three channels so that every decoded image (including grayscale
    # JPEGs) is compatible with the static [200, 200, 3] shape set below.
    image = tf.image.decode_jpeg(value, channels=3)
    print(image)

    # Bring every image to the same spatial size; batching needs one shape.
    image_resized = tf.image.resize_images(image, size=[200, 200])
    print(image_resized)

    # Pin the static shape so tf.train.batch knows the element shape.
    image_resized.set_shape(shape=[200, 200, 3])

    # 3. Batch.
    image_batch = tf.train.batch(
        [image_resized], batch_size=100, num_threads=1, capacity=100
    )
    print("image_batch:", image_batch)

    with tf.Session() as sess:
        # The coordinator manages the queue-runner threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Evaluated results get their own names so the graph tensors
            # above are not overwritten by their values.
            key_new, image_new, resized_new, batch_new = sess.run(
                [key, image, image_resized, image_batch]
            )
            print("key_new:", key_new)
            print("image_new", image_new)
            print("image_resized", resized_new)
            print("image_batch", batch_new)
        finally:
            # Always stop and join the queue threads, even if sess.run fails.
            coord.request_stop()
            coord.join(threads)
if __name__ == "__main__":
    # Prefix every entry found in ./dog with the directory path, then run
    # the image-reading pipeline on the resulting list.
    image_dir = "./dog/"
    file_list = [
        os.path.join(image_dir, entry) for entry in os.listdir("./dog")
    ]
    picture_read(file_list)
示例二:读取二进制文件
本示例采用的是 CIFAR-10 binary version (suitable for C programs) 数据集,每条记录由 1 字节标签和 3072 字节图像数据(3×32×32)组成。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
class Cifer(object):
    """Reader for fixed-length binary image records (CIFAR-10 binary layout).

    Each record is ``label_bytes`` bytes of label followed by
    ``image_bytes`` bytes of image data stored channel-first.
    """

    def __init__(self):
        """Set the image geometry and the derived per-record byte counts."""
        self.height = 32
        self.width = 32
        self.channels = 3
        # Byte counts of one record.
        self.image_bytes = self.height * self.width * self.channels
        self.label_bytes = 1
        self.all_bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list):
        """Read binary records through a TF1 queue pipeline and fetch a batch.

        Args:
            file_list: list of paths to binary files whose records are
                ``self.all_bytes`` bytes long.

        Returns:
            Tuple ``(image_value, label_value)``: the evaluated image batch
            (float32, height x width x channels per sample) and the
            evaluated label batch, each holding 100 samples.
        """
        # Build the filename queue.
        file_queue = tf.train.string_input_producer(file_list)

        # Read one fixed-length record at a time.
        reader = tf.FixedLengthRecordReader(self.all_bytes)
        # key is the file name, value is one raw sample.
        key, value = reader.read(file_queue)
        print("key:", key)
        print("value:", value)

        # Decode the raw bytes into a flat uint8 vector.
        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:", decoded)

        # Split the vector into the label part and the image part.
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
        print("label:", label)
        print("image:", image)

        # The bytes are laid out channel-first, so reshape accordingly.
        image_reshaped = tf.reshape(
            image, shape=[self.channels, self.height, self.width]
        )
        print("image_reshaped:", image_reshaped)

        # Reorder the axes to height, width, channels.
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed", image_transposed)

        # Convert pixel values to float32.
        image_cast = tf.cast(image_transposed, tf.float32)

        # Batch labels and images together.
        label_batch, image_batch = tf.train.batch(
            [label, image_cast], batch_size=100, num_threads=1, capacity=100
        )
        print("label_batch", label_batch)
        print("image_batch", image_batch)

        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                (key_new, decoded_new, label_new, image_new,
                 image_reshaped_new, image_transposed_new) = sess.run(
                    [key, decoded, label, image,
                     image_reshaped, image_transposed]
                )
                label_value, image_value = sess.run([label_batch, image_batch])
                print("key_new:", key_new)
                print("decoded_new", decoded_new)
                print("label_new", label_new)
                print("image_new", image_new)
                print("image_reshaped_new", image_reshaped_new)
                print("image_transposed_new", image_transposed_new)
                # Print the evaluated batches, not the symbolic tensors
                # (those were already printed before the session started).
                print("label_batch", label_value)
                print("image_batch", image_value)
            finally:
                # Always stop and join the queue threads, even on failure.
                coord.request_stop()
                coord.join(threads)
        return image_value, label_value
if __name__ == "__main__":
    # Keep only the entries whose name ends in "bin" and turn them into
    # full paths inside the data directory.
    data_dir = "./cifar-10-binary/cifar-10-batches-bin/"
    file_list = []
    for entry in os.listdir("./cifar-10-binary/cifar-10-batches-bin"):
        if not entry.endswith("bin"):
            continue
        file_list.append(os.path.join(data_dir, entry))

    # Instantiate the reader and fetch one batch of images and labels.
    cifar = Cifer()
    image_batch, label_batch = cifar.read_and_decode(file_list)
|