0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

GRU模型实战训练 智能决策更精准

恩智浦MCU加油站 来源:恩智浦MCU加油站 2024-06-13 09:22 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

上一期文章带大家认识了一个名为GRU的新朋友, GRU本身自带处理时序数据的属性,特别擅长对于时间序列的识别和检测(例如音频传感器信号等)。GRU其实是RNN模型的一个衍生形式,巧妙地设计了两个门控单元:reset门和更新门。reset门负责针对历史遗留的状态进行重置,丢弃掉无用信息;更新门负责对历史状态进行更新,将新的输入与历史数据集进行整合。通过模型训练,让模型能够自动调整这两个门控单元的状态,以期达到历史数据与最新数据和谐共存的目的。

理论知识掌握了,下面就来看看如何训练一个GRU模型吧。

训练平台选用Keras,请提前自行安装Keras开发工具。直接上代码,首先是数据导入部分,我们直接使用mnist手写字体数据集:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model


# 准备数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

模型构建与训练:

# 构建GRU模型
model = Sequential()
model.add(GRU(128, input_shape=(28, 28), stateful=False, unroll=False))
model.add(Dense(10, activation='softmax'))


# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


# 模型训练
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))

这里,眼尖的伙伴应该是注意到了,GRU模型构建的时候,有两个参数,分别是stateful以及unroll,这两个参数是什么意思呢?

GRU层的stateful和unroll是两个重要的参数,它们对GRU模型的行为和性能有着重要影响:

stateful参数:默认情况下,stateful参数为False。当stateful设置为True时,表示在处理连续的数据时,GRU层的状态会被保留并传递到下一个时间步,而不是每个batch都重置状态。这对于处理时间序列数据时非常有用,例如在处理长序列时,可以保持模型的状态信息,而不是在每个batch之间重置。需要注意的是,在使用stateful时,您需要手动管理状态的重置。

unroll参数:默认情况下,unroll参数为False。当unroll设置为True时,表示在计算时会展开RNN的循环,这样可以提高计算性能,但会增加内存消耗。通常情况下,对于较短的序列,unroll设置为True可以提高计算速度,但对于较长的序列,可能会导致内存消耗过大。

通过合理设置stateful和unroll参数,可以根据具体的数据和模型需求来平衡模型的状态管理和计算性能。而我们这里用到的mnist数据集实际上并不是时间序列数据,而只是将其当作一个时序数据集来用。因此,每个batch之间实际上是没有显示的前后关系的,不建议使用stateful。而是每一个batch之后都要将其状态清零。即stateful=False。而unroll参数,大家就可以自行测试了。

模型评估与转换:

# 模型评估
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


# 保存模型
model.save("mnist_gru_model.h5")


# 加载模型并转换
converter = tf.lite.TFLiteConverter.from_keras_model(load_model("mnist_gru_model.h5"))
tflite_model = converter.convert()


# 保存tflite格式模型
with open('mnist_gru_model.tflite', 'wb') as f:
    f.write(tflite_model)



便写好程序后,运行等待训练完毕,可以看到经过10个epoch之后,模型即达到了98.57%的测试精度:

44c1e04e-291f-11ef-91d2-92fbcf53809c.png

来看看最终的模型样子,参数stateful=False,unroll=True:

44e91506-291f-11ef-91d2-92fbcf53809c.png

这里,我们就会发现,模型的输入好像被拆分成了很多份,这是因为我们指定了输入是28*28。第一个28表示有28个时间步,后面的28则表示每一个时间步的维度。这里的时间步,指代的就是历史的数据。

现在,GRU模型训练就全部介绍完毕了,对于机器学习深度学习感兴趣的伙伴们,不妨亲自动手尝试一下,搭建并训练一个属于自己的GRU模型吧!

希望每一位探索者都能在机器学习的道路上不断前行,收获满满的知识和成果!

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • Gru
    Gru
    +关注

    关注

    0

    文章

    12

    浏览量

    7730
  • 机器学习
    +关注

    关注

    66

    文章

    8541

    浏览量

    136236
  • rnn
    rnn
    +关注

    关注

    0

    文章

    92

    浏览量

    7302

原文标题:GRU模型实战训练,智能决策更精准!

文章出处:【微信号:NXP_SMART_HARDWARE,微信公众号:恩智浦MCU加油站】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    在Ubuntu20.04系统中训练神经网络模型的一些经验

    本帖欲分享在Ubuntu20.04系统中训练神经网络模型的一些经验。我们采用jupyter notebook作为开发IDE,以TensorFlow2为训练框架,目标是训练一个手写数字识
    发表于 10-22 07:03

    无人驾驶:智能决策精准执行的融合

    无人驾驶核心操控技术:智能决策精准执行的融合 无人驾驶的核心操控系统是车辆实现自主驾驶的“大脑”与“四肢”,其技术核心在于通过感知、决策、执行三大模块的协同工作,替代人类驾驶员完成实
    的头像 发表于 09-19 14:03 467次阅读

    不仅管设备,还能管数据!智能系统让运维决策更精准

    智能系统在设备管理领域的应用,为企业带来了全方位的价值提升。它不仅实现了对设备的高效管理,更通过强大的数据管理能力,为运维决策提供了精准依据,帮助企业降低成本、提高生产效率、增强市场竞争力。
    的头像 发表于 09-05 10:10 599次阅读
    不仅管设备,还能管数据!<b class='flag-5'>智能</b>系统让运维<b class='flag-5'>决策</b><b class='flag-5'>更精准</b>

    make sence成的XML文件能上传到自助训练模型上吗?

    make sence成的XML文件能上传到自助训练模型上吗
    发表于 06-23 07:38

    动态感知+智能决策,一文解读 AI 场景组网下的动态智能选路技术

    人工智能(AI),特别是大规模模型训练和推理,正以前所未有的方式重塑数据中心网络。传统的“尽力而为”网络架构,在处理海量、突发的AI数据洪流时捉襟见肘。AI模型对网络性能的严苛要求——
    的头像 发表于 06-20 15:01 1321次阅读
    动态感知+<b class='flag-5'>智能</b><b class='flag-5'>决策</b>,一文解读 AI 场景组网下的动态<b class='flag-5'>智能</b>选路技术

    宇视科技梧桐大模型赋能交通治理

    迭代已实现多场景任务覆盖。近期推出的“梧桐”大模型事件检测和交通抓拍系列产品,将AI能力深度落地于交通治理场景,在算法精度、成像质量和分析性能等方面实现显著提升,赋能交通治理更精准、更智能、更高效。
    的头像 发表于 05-16 17:23 810次阅读

    请问如何在imx8mplus上部署和运行YOLOv5训练模型

    我正在从事 imx8mplus yocto 项目。我已经在自定义数据集上的 YOLOv5 上训练了对象检测模型。它在 ubuntu 电脑上运行良好。现在我想在我的 imx8mplus 板上运行该模型
    发表于 03-25 07:23

    数据标注服务—奠定大模型训练的数据基石

    数据标注是大模型训练过程中不可或缺的基础环节,其质量直接影响着模型的性能表现。在大模型训练中,数据标注承担着将原始数据转化为机器可理解、可学
    的头像 发表于 03-21 10:30 2309次阅读

    训练好的ai模型导入cubemx不成功怎么处理?

    训练好的ai模型导入cubemx不成功咋办,试了好几个模型压缩了也不行,ram占用过大,有无解决方案?
    发表于 03-11 07:18

    小白学大模型训练大语言模型的深度指南

    在当今人工智能飞速发展的时代,大型语言模型(LLMs)正以其强大的语言理解和生成能力,改变着我们的生活和工作方式。在最近的一项研究中,科学家们为了深入了解如何高效地训练大型语言模型,进
    的头像 发表于 03-03 11:51 1213次阅读
    小白学大<b class='flag-5'>模型</b>:<b class='flag-5'>训练</b>大语言<b class='flag-5'>模型</b>的深度指南

    AI赋能边缘网关:开启智能时代的新蓝海

    功耗的AI边缘计算平台;对于算法企业,要研发更轻量化、更精准的边缘AI模型;对于系统集成商,则要构建完整的边缘智能解决方案。这个万亿级的新市场,正在等待更多创新者的加入。 在这场AI与边缘计算融合的产业革命
    发表于 02-15 11:41

    腾讯公布大语言模型训练新专利

    近日,腾讯科技(深圳)有限公司公布了一项名为“大语言模型训练方法、装置、计算机设备及存储介质”的新专利。该专利的公布,标志着腾讯在大语言模型训练领域取得了新的突破。 据专利摘要显示,
    的头像 发表于 02-10 09:37 723次阅读

    【「大模型启示录」阅读体验】+开启智能时代的新钥匙

    阅读之旅。在翻开这本书之前,我对大模型的认知仅仅停留在它是一种强大的人工智能技术,可以进行自然语言处理、图像识别等任务。我知道像 ChatGPT 这样的应用是基于大模型开发的,能够与人类进行较为流畅
    发表于 12-24 13:10

    GPU是如何训练AI大模型

    在AI模型训练过程中,大量的计算工作集中在矩阵乘法、向量加法和激活函数等运算上。这些运算正是GPU所擅长的。接下来,AI部落小编带您了解GPU是如何训练AI大模型的。
    的头像 发表于 12-19 17:54 1347次阅读

    【「大模型启示录」阅读体验】如何在客服领域应用大模型

    地选择适合的模型。不同的模型具有不同的特点和优势。在客服领域,常用的模型包括循环神经网络(RNN)、长短时记忆网络(LSTM)、门控循环单元(GRU)、Transformer等,以及基
    发表于 12-17 16:53