0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

构建简单数据管道,为什么tf.data要比feed_dict更好?

zhKF_jqr_AI 来源:lq 2018-12-03 09:08 次阅读

在大多数面向初学者的TensorFlow教程里,作者通常会建议读者在会话中用feed_dict为模型导入数据——feed_dict是一个字典,能为占位符馈送数据。但是,其实TF提供了另一种更好的、更简单的方法:只需使用tf.dataAPI,你就能用几行代码搞定高性能数据管道。

那么tf.data的优势具体在哪里呢?如下图所示,虽然feed_dict的灵活性大家有目共睹,但每当我们需要等待CPU把数据馈送进来时,GPU就一直处于闲置状态,也就是程序运行效率太低。

而tf.data管道没有这个问题,它能提前抓取下个batch的数据,降低总体闲置时间。在这个基础上,如果我们采用并行数据导入,或者事先进行数据预处理,整个过程就更快了。

在5分钟内实现小型图像管道

要构建一个简单数据管道,首先我们需要两个对象:一个用于存储数据集的tf.data.Dataset,以及一个允许我们逐个从数据集中提取样本的tf.data.Iterator。

对于tf.data.Dataset,它在图像管道中是这样的:

[

[Tensor(image), Tensor(label)],

[Tensor(image), Tensor(label)],

...

]

之后我们就可以用tf.data.Iterator逐个检索图像标签对。在实践中,多个图像标签对通常会组成元素序列,方便迭代器进行提取。

至于数据集,DatasetAPI有两种创建数据集的方法,其一是从源(如Python中的文件名列表)创建数据集,其二是可以直接在现有数据集上应用转换,下面是一些示例:

Dataset(list of image files) → Dataset(actual images)

Dataset(6400 images) → Dataset(64 batches with 100 images each)

Dataset(list of audio files) → Dataset(shuffled list of audio files)

定义计算图

小型图像管道的大致情况如下图所示:

所有代码都和模型、损失、优化器等一起放在我们的计算图定义中。首先,我们要从文件列表中创建一个张量。

# define list of files

files = ['a.png', 'b.png', 'c.png', 'd.png']

# create a dataset from filenames

dataset = tf.data.Dataset.from_tensor_slices(files)

之后是定义一个函数来从其路径加载图像(作为张量),并调用tf.data.Dataset.map()把函数用于数据集中的所有元素(文件路径)。如果想并行调用函数,你也可以设置num_parallel_calls=n里的map()参数

# Source

def load_image(path):

image_string = tf.read_file(path)

# Don't use tf.image.decode_image, or the output shape will be undefined

image = tf.image.decode_jpeg(image_string, channels=3)

# This will convert to float values in [0, 1]

image = tf.image.convert_image_dtype(image, tf.float32)

image = tf.image.resize_images(image, [image_size, image_size])

return image

# Apply the function load_image to each filename in the dataset

dataset = dataset.map(load_image, num_parallel_calls=8)

然后是用tf.data.Dataset.batch()创建batch:

# Create batches of 64 images each

dataset = dataset.batch(64)

如果想减少GPU闲置时间,我们可以在管道末尾添加tf.data.Dataset.prefetch(buffer_size),其中buffer_size这个参数表示预抓取的batch数,我们一般设buffer_size=1,但在某些情况下,尤其是处理每个batch耗时不同时,我们也可以适当扩大一点。

dataset = dataset.prefetch(buffer_size=1)

最后,我们再创建一个迭代器遍历数据集。虽然迭代器的选择有很多,但对于大多数任务,我们还是建议选择可以初始化的迭代器。

iterator = dataset.make_initializable_iterator()

调用tf.data.Iterator.get_next()创建占位符张量,每次评估时,TensorFlow都会填充下一batch的图像。

batch_of_images = iterator.get_next()

如果写到这里,你突然想换回feed_dict的方法,你可以用batch_of_images把之前的占位符全都替换掉。

运行会话

现在,我们就可以向往常一样运行模型了。但在每个epoch前,记得先评估iterator.initializer的op和tf.errors.OutOfRangeError有没有抛出异常。

with tf.Session() as session:

for i in range(epochs):

session.run(iterator.initializer)

try:

# Go through the entire dataset

whileTrue:

image_batch = session.run(batch_of_images)

except tf.errors.OutOfRangeError:

print('End of Epoch.')

nvidia-smi这个命令可以帮我们监控GPU利用率,找到数据管道中的瓶颈。正常情况下,GPU的平均利用率应该高于70%-80%。

更完整的数据管道

Shuffle

在Dataset里,tf.data.Dataset.shuffle()是一个比较常用的方法,它可以用来打乱数据集中的数据顺序。它的参数buffer_size指定的是一次打乱的元素数量,一般情况下,我们建议把这个参数值设大一点,最好一次性就能把整个数据集洗牌,因为如果参数过小,它可能会造成意料之外的偏差。

dataset = tf.data.Dataset.from_tensor_slices(files)

dataset = dataset.shuffle(len(files))

数据增强

数据增强是扩大数据集的一种常用方式,这方面常用的函数有tf.image.random_flip_left_right()、tf.image.random_brightness()和tf.image.random_saturation():

# Source

def train_preprocess(image):

image = tf.image.random_flip_left_right(image)

image = tf.image.random_brightness(image, max_delta=32.0 / 255.0)

image = tf.image.random_saturation(image, lower=0.5, upper=1.5)

# Make sure the image is still in [0, 1]

image = tf.clip_by_value(image, 0.0, 1.0)

return image

标签

要想在图像上加载标签(或其他元数据),我们只需在创建初始数据集时就把它们包含在内:

# files is a python list of image filenames

# labels is a numpy array with label data for each image

dataset = tf.data.Dataset.from_tensor_slices((files, labels))

确保应用于数据集的所有.map()函数都允许标签数据通过:

def load_image(path, label):

# load image

return image, label

dataset = dataset.map(load_image)

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • gpu
    gpu
    +关注

    关注

    27

    文章

    4403

    浏览量

    126563
  • 数据集
    +关注

    关注

    4

    文章

    1176

    浏览量

    24340

原文标题:构建简单数据管道,为什么tf.data要比feed_dict更好?

文章出处:【微信号:jqr_AI,微信公众号:论智】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    带校时功能的简单数字钟

    用555制成振荡器,74LS90制成分频器,带校时功能的简单数字钟。
    发表于 02-01 18:35

    简单数据采集系统怎样清除上次采集的数据

    各位前辈:我用labview编了一个简单数据采集系统,现在想在暂停后,再次运行时,清除掉上次采集的数据,需要怎样做啊??
    发表于 03-07 21:29

    使用 TensorFlow, 你必须明白 TensorFlow

    ([output], feed_dict={input1:[7.], input2:[2.]})# 输出:# [array([ 14.], dtype=float32)]for a
    发表于 03-30 20:03

    Tensorflow快餐教程(1) - 30行代码搞定手写识别

    ) == sess.run(predict_op, feed_dict={X: teX})))经过100轮的训练,我们的准确率是92.36%。无脑的浅层神经网络用了最简单的线性模型,我们换成经典的神经网络来实现这个功能
    发表于 04-28 16:08

    TFTF定义两个变量相乘之placeholder先hold类似变量+feed_dict最后外界传入值

    TFTF定义两个变量相乘之placeholder先hold类似变量+feed_dict最后外界传入值
    发表于 12-21 10:35

    内建属性__dict__的简单概述

    Python中内建属性__dict__
    发表于 07-24 10:16

    简单数字钟电路图

    简单数字钟电路图
    发表于 01-08 11:11 179次下载

    简单数字钟电路图

    简单数字钟电路图
    发表于 01-08 11:11 103次下载

    如何将自定义图片输入到TensorFlow的训练模型

    对于上述代码中与模型构建相关的代码,请查阅官方《Deep MNIST for Experts》一节的内容进行理解。在本文中,需要重点掌握的是如何将本地图片源整合成为feed_dict可接受的格式。其中最关键的是这两行
    的头像 发表于 08-17 15:57 8702次阅读

    使用tensorflow构建一个简单神经网络

    给大家分享一个案例,如何使用tensorflow 构建一个简单神经网络。首先我们需要创建我们的样本,由于是监督学习,所以还是需要label的。为了简单起见,我们只创建一个样本进行训练, 可以当做
    的头像 发表于 10-16 08:41 2117次阅读

    tf.data API的功能和最佳实践操作

    tf.data API 通过 tf.data.Dataset.prefetch 转换提供了一个软件 pipelining 操作机制,该转换可用于将数据生成的时间与所消耗时间分离。特别是,转换使用后
    的头像 发表于 01-11 13:51 1.3w次阅读

    TensorFlow 2.0将专注于简单性和易用性

    使用 tf.data 加载数据。使用输入管道读取训练数据,用 tf.data 创建的输入线程读取训练数据
    的头像 发表于 01-18 10:44 2424次阅读
    TensorFlow 2.0将专注于<b class='flag-5'>简单</b>性和易用性

    python教程之变量和简单数据类型

    本文档的主要内容详细介绍的是python教程之变量和简单数据类型。
    发表于 04-26 08:00 7次下载
    python教程之变量和<b class='flag-5'>简单数据</b>类型

    C语言简单数据解析

    C语言简单数据解析​ 在嵌入式开发中通过串口等传输数据通常使用JSON解析,虽然JSON十分强大,但JSON耗费资源太多,数据的打包和解析都比较麻烦。有时我们只是传输一些简单
    发表于 01-13 15:17 8次下载
    C语言<b class='flag-5'>简单数据</b>解析

    使用tf.data进行数据集处理

    在进行AI模型训练过程前,需要对数据集进行处理, Tensorflow提供了tf.data数据集处理模块,通过该接口能够轻松实现数据集预处理。tf.
    的头像 发表于 11-29 15:34 870次阅读