0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

基于TensorFlow的数据导入机制

wpl4_DeepLearni 来源:未知 作者:李倩 2018-04-02 14:44 次阅读

聊一聊TensorFlow的数据导入机制

今天我们要讲的是TensorFlow中的数据导入机制,传统的做法是习惯于先构建好TF图模型,然后开启一个会话(Session),在运行图模型之前将数据feed到图中,这种做法的缺点是数据IO带来的时间消耗很大,那么在训练非常庞大的数据集的时候,不提倡采用这种做法,TensorFlow中取而代之的是tf.data.Dataset模块,今天我们重点介绍这个。

tf.data是一个十分强大的可以用于构建复杂的数据导入机制的API,例如,如果你要处理的是图像,那么tf.data可以帮助你把分布在不同位置的文件整合到一起,并且对每幅图片添加微小的随机噪声,以及随机选取一部分图片作为一个batch进行训练;又或者是你要处理文本,那么tf.data可以帮助从文本中解析符号并且转换成embedding矩阵,然后将不同长度的序列变成一个个batch。

我们可以用tf.data.Dataset来构建一个数据集,数据集的来源可以有多种方式,例如如果你的数据集是预先以TFRecord格式写在硬盘上的,那么你可以用tf.data.TFRecordDataset来构建;如果你的数据集是内存中的tensor变量,那么可以用tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()来构建。下面我将通过代码来演示它们。

首先,我们来看从内存中的tensor变量来构建数据集,如下代码所示,首先构建了一个0~10的数据集,然后构建迭代器,迭代器可以每次从数据集中提取一个元素:

import tensorflow as tf dataset=tf.data.Dataset.range(10) iterator=dataset.make_one_shot_iterator() next_element = iterator.get_next()with tf.Session() as sess: for _ in range(10): print(sess.run(next_element))

如上代码所示,range()是tf.data.Dataset类的一个静态函数,用于产生一段序列。需要注意的是,构建的数据集需要是同一种数据类型以及内部结构。除此之外,由于range(10)代表0~9一共十个数,因此,这里的iterator只能运行10次,超过以后将会抛出tf.errors.OutOfRangeError异常。如果希望不抛出异常,则可以调用dataset.repeat(count)即可实现count次自动重复的迭代器。

range的范围我们也可以在运行时才确定,即定义max_range为placeholder变量,这个时候需要调用Dataset的make_initializable_iterator方法来构建迭代器,并且这个迭代器的operation需要在迭代之前被运行,代码如下所示:

max_range=tf.placeholder(tf.int64, shape=[]) dataset = tf.data.Dataset.range(max_range) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()with tf.Session() as sess: sess.run(iterator.initializer, feed_dict={max_range: 10}) for _ in range(10): print(sess.run(next_element))

也可以为不同的数据集创建同一个迭代器,为了使得这个迭代器可以被重复使用,需要保证不同数据集的类型和维度是一致的。例如,下面的代码演示了如何使用同一个迭代器来构建训练集和验证集,可以看到,当我们开始训练训练集的时候,就需要先执行training_init_op,目的是使得迭代器开始加载训练数据;而当进行验证的时候,则需要先执行validation_init_op,道理一样。

training_data = tf.data.Dataset.range(100).map(lambda x: x+tf.random_uniform([], -10, 10, tf.int64)) validation_data = tf.data.Dataset.range(50) iterator = tf.Iterator.from_structure(training_data.output_types, training_data.output_shapes) iterator = tf.data.Iterator.from_structure(training_data.output_types, training_data.output_shapes) next_element = iterator.get_next() training_init_op=iterator.make_initializer(training_data) validation_init_op=iterator.make_initializer(validation_data)with tf.Session() as sess: for epoch in range(10): sess.run(training_init_op) for _ in range(100): sess.run(next_element) sess.run(validation_init_op) for _ in range(50): sess.run(next_element)

也可以通过Tensor变量构建tf.data.Dataset,如下代码所示,需要注意的是,这里的Tensor的维度是4×10,因此,传入到迭代器中就是可以运行4次,每次运行生成一个长度为10的向量。

import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()with tf.Session() as sess: sess.run(iterator.initializer) for i in range(4): value = sess.run(next_element) print(value)

最后,还有一种比较常见的读取数据的方式,就是从TFRecord文件中去读取,这里再介绍一下之前在语音识别项目里采取的TFRecord的读写代码。

首先是将音频特征写入到TFRecord文件之中,在语音识别中,我们最常用的两个特征就是MFCC和LogFBank,要写入文件中的不仅仅是这两个变量,还要有文本标签Label以及特征序列的长度sequence_legnth,这四个变量中,只有sequence_length是整数标量,其他三个都是列表格式,所以这里对于列表使用字节来保存,而对于标量,使用整型来保存。

def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))class RecordWriter(object): def __init__(self): pass def write(self, content, tfrecords_filename): writer = tf.python_io.TFRecordWriter(tfrecords_filename) if isinstance(content, list): feature_dict = {} for i in range(len(content)): feature = content[i] if i==0: feature_raw = np.array(feature).tostring() feature_dict['mfccFeat']=_bytes_feature(feature_raw) elif i==1: feature_raw = np.array(feature).tostring() feature_dict['logfbankFeat']=_bytes_feature(feature_raw) elif i==2: feature_raw = np.array(feature).tostring() feature_dict['label']=_bytes_feature(feature_raw) else: feature_dict['sequence_length']=_int64_feature(feature) features_to_write = tf.train.Example(features=tf.train.Features(feature=feature_dict)) writer.write(features_to_write.SerializeToString()) writer.close() print('Record has been writen:'+tfrecords_filename)

写好TFRecord以后,在读取的时候首先需要对TFRecord格式文件进行解析,解析函数如下:

def parse(self, serialized): feature_dict={} feature_dict['mfccFeat']=tf.FixedLenFeature([], tf.string) feature_dict['logfbankFeat']=tf.FixedLenFeature([], tf.string) feature_dict['label']=tf.FixedLenFeature([], tf.string) feature_dict['sequence_length']=tf.FixedLenFeature([1], tf.int64) features = tf.parse_single_example( serialized, features=feature_dict) mfcc = tf.reshape(tf.decode_raw(features['mfccFeat'], tf.float32), [-1, self.feature_num]) logfbank = tf.reshape(tf.decode_raw(features['logfbankFeat'], tf.float32), [-1, self.feature_num]) label = tf.decode_raw(features['label'], tf.int64) return mfcc, logfbank, label, features['sequence_length']

然后我们可以直接通过调用tf.data.TFRecordDataset来导入TFRecord文件列表,以及对每个文件调用parse函数进行解析,并且由于每个文件的特征矩阵长度不一,所以需要对齐进行padding操作,最终可以获得迭代器,代码如下:

self.fileNameList = tf.placeholder(tf.string, [None, ]) padded_shapes= ([-1,feature_num],[-1,feature_num],[-1],[1]) padded_values = (0.0,0.0,np.int64(-1),np.int64(0)) dataset = tf.data.TFRecordDataset(self.fileNameList, buffer_size=self.buffer_size).map(self.parse, num_parallel_call).padded_batch(batch_size, padded_shapes, padded_values) self.iterator = tf.data.Iterator.from_structure((tf.float32, tf.float32, tf.int64, tf.int64), (tf.TensorShape([None, None, 60]), tf.TensorShape([None, None, 60]), tf.TensorShape([None, None]), tf.TensorShape([None, None]))) self.initializer = self.iterator.make_initializer(dataset)

于是,关于TFRecord文件的读写就介绍完了,并且,基于TensorFlow的数据导入机制也介绍完了。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据集
    +关注

    关注

    4

    文章

    1178

    浏览量

    24349
  • tensorflow
    +关注

    关注

    13

    文章

    313

    浏览量

    60242

原文标题:聊一聊TensorFlow的数据导入机制

文章出处:【微信号:DeepLearningDigest,微信公众号:深度学习每日摘要】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    关于 TensorFlow

    在节点间相互联系的多维数据数组,即张量(tensor)。它灵活的架构让你可以在多种平台上展开计算,例如台式计算机中的一个或多个CPU(或GPU),服务器,移动设备等等。TensorFlow 最初由
    发表于 03-30 19:57

    使用 TensorFlow, 你必须明白 TensorFlow

    数据, 计算图中, 操作间传递的数据都是 tensor. 你可以把 TensorFlow tensor 看作是一个 n 维的数组或列表. 一个 tensor 包含一个静态类型 rank, 和 一个
    发表于 03-30 20:03

    TensorFlow运行时无法加载本机

    )Python分发由英特尔公司提供给您。请查看:https://software.intel.com/en-us/python-distribution>>>将tensorflow导入为tf回溯
    发表于 10-19 12:00

    导入tensorflow时未找到“GLIBC_2.23”错误

    导入tensorflow时,它给我一个错误。附上错误的屏幕截图。请帮忙。GLIBC_Error.PNG 40.9 K.以上来自于谷歌翻译以下为原文Hi, I created a new conda
    发表于 11-14 09:59

    情地使用Tensorflow吧!

    在节点间相互联系的多维数据数组,即张量(tensor)。它灵活的架构让你可以在多种平台上展开计算,例如台式计算机中的一个或多个CPU(或GPU),服务器,移动设备等等。TensorFlow 最初由
    发表于 07-22 10:13

    TensorFlow是什么

    TensorFlow 在深度学习模型中的应用,使读者可以轻松地将模型用于数据集并开发有用的应用程序。每章包含一系列处理技术问题、依赖性、代码和解读的示例,在每章的最后,还有一个功能完善的深度学习模型。
    发表于 07-22 10:14

    TensorFlow教程|常见问题

    ,参看 TensorFlow使用 GPU ; 使用多 GPU 的示范实例参看 CIFAR-10 教程 。可用的 tensor 有哪些不同的类型?TensorFlow 支持许多种不同的数据
    发表于 07-27 18:33

    TensorFlow csv文件读取数据(代码实现)详解

    大多数人了解 Pandas 及其在处理大数据文件方面的实用性。TensorFlow 提供了读取这种文件的方法。前面章节中,介绍了如何在 TensorFlow 中读取文件,本节将重点介绍如何从 CSV
    发表于 07-28 14:40

    TensorFlow实现简单线性回归

    本小节直接从 TensorFlow contrib 数据集加载数据。使用随机梯度下降优化器优化单个训练样本的系数。实现简单线性回归的具体做法导入需要的所有软件包: 在神经网络中,所有的
    发表于 08-11 19:34

    TensorFlow实现多元线性回归(超详细)

    TensorFlow 实现简单线性回归的基础上,可通过在权重和占位符的声明中稍作修改来对相同的数据进行多元线性回归。在多元线性回归的情况下,由于每个特征具有不同的值范围,归一化变得至关重要
    发表于 08-11 19:35

    TensorFlow逻辑回归处理MNIST数据

    [0000000010]:具体做法导入所需的模块: 可以从模块 input_data 给出的 TensorFlow 示例中获取 MNIST 的输入数据。该 one_hot 标志设置为真,以使用标签
    发表于 08-11 19:36

    TensorFlow逻辑回归处理MNIST数据

    [0000000010]:具体做法导入所需的模块: 可以从模块 input_data 给出的 TensorFlow 示例中获取 MNIST 的输入数据。该 one_hot 标志设置为真,以使用标签
    发表于 08-11 19:36

    如何用TensorFlow导入MNIST数据集?

    TensorFlow导入MNIST数据
    发表于 11-11 07:33

    图文详解tensorflow数据读取机制

    tensorflow数据读取机制,文章的最后还会给出实战代码以供参考。 授权转载:知乎专栏 AI Insight 一、tensorflow读取机制
    发表于 09-22 16:41 1次下载

    TensorFlow数据读取机制分析

    解释一下TensorFlow数据读取机制,文章的最后还会给出实战代码以供参考。 TensorFlow读取机制图解 首先需要思考的一个问题是
    发表于 09-28 17:45 0次下载
    <b class='flag-5'>TensorFlow</b><b class='flag-5'>数据</b>读取<b class='flag-5'>机制</b>分析