- 使用Tensorflow的tf.data.Dataset总结
1、在使用tf.data建立Dataset时有两种方式: 使用本地电脑内存上的数据: tf.data.Dataset.from_tensors() tf.data.Dataset.from_tensor_slices() 需要的是列表或者其他作为输入 使用TFRecord 格式数据: tf.data.TFRecordDataset() 得到的dataset是python的一个可迭代对象 2、tf.data.Dataset.from_generator 用于将Python的生成器转变为Dataset 例子:这其中必须要有输出类型(output_types)和输出形状(output_shapes)。因为tf.data会建造一个tf.Graph,而图边缘必须要有output_typestf.dtype。
def count(stop):
i = 0
while i<stop:
yield i
i += 1
for n in count(5):
print(n)
>>>0
>>>1
>>>2
>>>3
>>>4
ds_counter = tf.data.Dataset.from_generator(count, args=[25],
output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
print(count_batch.numpy())
3、利用 tf.image.ImageDataGenerator 进行数据增强 : 首先创建一个ImageDataGenerator
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
然后使用从路径读取图片不断产生数据,形成的生成器用法与ImageDataGenerator类的方法一致。比如使用flow_from_directory()
images, labels = next(img_gen.flow_from_directory(flowers))
>>>Found 3670 images belonging to 5 classes.
最后调用from_genrtor()形成Dataset:
ds = tf.data.Dataset.from_generator(
lambda: img_gen.flow_from_directory(flowers),
output_types=(tf.float32, tf.float32),
output_shapes=([32,256,256,3], [32,5])
)
ds.element_spec
>>>(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
4、使用Dataset.map(f)进行数据预处理 重点理解:Dataset.map(f) 这个函数通过对输入数据集的每个元素应用给定的函数来生成一个新的数据集。它是基于python中的map()函数的,但一定要注意的是与python中的有很大不同。 一定注意的是这个函数的输入要求是tf.Tensor ,返回的输出也是张量对象tf.Tensor 。这里很容易出错,经常直接利用array来作为输入。 最重要的是!!!: 这个函数他是对整个Dataset进行映射,相当于无论原始的Dataset有多少个张量在其中都是直接一对一直接应用函数转换为另一个Dataset。所以映射函数f 一般都是TensorFlow的操作,就是一般都是tf.XX 这样的函数。 它的实现是使用标准的TensorFlow操作将一个元素转换为另一个元素。 例子:
list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))
def parse_image(filename):
parts = tf.strings.split(filename, os.sep)
label = parts[-2]
image = tf.io.read_file(filename)
image = tf.io.decode_jpeg(image)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [128, 128])
return image, label
images_ds = list_ds.map(parse_image)
for image, label in images_ds.take(2):
show(image, label)
参考:tf.data: Build TensorFlow input pipelines
|