0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

GRU模型实战训练 智能决策更精准

恩智浦MCU加油站 来源:恩智浦MCU加油站 2024-06-13 09:22 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

上一期文章带大家认识了一个名为GRU的新朋友, GRU本身自带处理时序数据的属性,特别擅长对于时间序列的识别和检测(例如音频传感器信号等)。GRU其实是RNN模型的一个衍生形式,巧妙地设计了两个门控单元:reset门和更新门。reset门负责针对历史遗留的状态进行重置,丢弃掉无用信息;更新门负责对历史状态进行更新,将新的输入与历史数据集进行整合。通过模型训练,让模型能够自动调整这两个门控单元的状态,以期达到历史数据与最新数据和谐共存的目的。

理论知识掌握了,下面就来看看如何训练一个GRU模型吧。

训练平台选用Keras,请提前自行安装Keras开发工具。直接上代码,首先是数据导入部分,我们直接使用mnist手写字体数据集:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model


# 准备数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

模型构建与训练:

# 构建GRU模型
model = Sequential()
model.add(GRU(128, input_shape=(28, 28), stateful=False, unroll=False))
model.add(Dense(10, activation='softmax'))


# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


# 模型训练
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))

这里,眼尖的伙伴应该是注意到了,GRU模型构建的时候,有两个参数,分别是stateful以及unroll,这两个参数是什么意思呢?

GRU层的stateful和unroll是两个重要的参数,它们对GRU模型的行为和性能有着重要影响:

stateful参数:默认情况下,stateful参数为False。当stateful设置为True时,表示在处理连续的数据时,GRU层的状态会被保留并传递到下一个时间步,而不是每个batch都重置状态。这对于处理时间序列数据时非常有用,例如在处理长序列时,可以保持模型的状态信息,而不是在每个batch之间重置。需要注意的是,在使用stateful时,您需要手动管理状态的重置。

unroll参数:默认情况下,unroll参数为False。当unroll设置为True时,表示在计算时会展开RNN的循环,这样可以提高计算性能,但会增加内存消耗。通常情况下,对于较短的序列,unroll设置为True可以提高计算速度,但对于较长的序列,可能会导致内存消耗过大。

通过合理设置stateful和unroll参数,可以根据具体的数据和模型需求来平衡模型的状态管理和计算性能。而我们这里用到的mnist数据集实际上并不是时间序列数据,而只是将其当作一个时序数据集来用。因此,每个batch之间实际上是没有显示的前后关系的,不建议使用stateful。而是每一个batch之后都要将其状态清零。即stateful=False。而unroll参数,大家就可以自行测试了。

模型评估与转换:

# 模型评估
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


# 保存模型
model.save("mnist_gru_model.h5")


# 加载模型并转换
converter = tf.lite.TFLiteConverter.from_keras_model(load_model("mnist_gru_model.h5"))
tflite_model = converter.convert()


# 保存tflite格式模型
with open('mnist_gru_model.tflite', 'wb') as f:
    f.write(tflite_model)



便写好程序后,运行等待训练完毕,可以看到经过10个epoch之后,模型即达到了98.57%的测试精度:

44c1e04e-291f-11ef-91d2-92fbcf53809c.png

来看看最终的模型样子,参数stateful=False,unroll=True:

44e91506-291f-11ef-91d2-92fbcf53809c.png

这里,我们就会发现,模型的输入好像被拆分成了很多份,这是因为我们指定了输入是28*28。第一个28表示有28个时间步,后面的28则表示每一个时间步的维度。这里的时间步,指代的就是历史的数据。

现在,GRU模型训练就全部介绍完毕了,对于机器学习深度学习感兴趣的伙伴们,不妨亲自动手尝试一下,搭建并训练一个属于自己的GRU模型吧!

希望每一位探索者都能在机器学习的道路上不断前行,收获满满的知识和成果!

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • Gru
    Gru
    +关注

    关注

    0

    文章

    13

    浏览量

    7759
  • 机器学习
    +关注

    关注

    67

    文章

    8561

    浏览量

    137208
  • rnn
    rnn
    +关注

    关注

    0

    文章

    92

    浏览量

    7373

原文标题:GRU模型实战训练,智能决策更精准!

文章出处:【微信号:NXP_SMART_HARDWARE,微信公众号:恩智浦MCU加油站】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    AI大模型微调企业项目实战

    自主可控大模型:企业微调实战课,筑牢未来 AI 底座 在人工智能席卷全球商业版图的今天,企业对大模型(LLM)的态度已经从“新奇观望”转变为“全面拥抱”。然而,随着应用层面的不断深入
    发表于 04-16 18:48

    AI落地培训 | FH8626V300L 人形检测模型嵌入式部署全链路实战

    举办“AI模型训练流程与部署实战”免费培训!本次培训以富瀚微电子FH8626V300L为硬件平台——面向智能高清网络摄像机应用的高性能SoC,集成高性能ISP和H
    的头像 发表于 04-15 18:12 103次阅读
    AI落地培训 | FH8626V300L 人形检测<b class='flag-5'>模型</b>嵌入式部署全链路<b class='flag-5'>实战</b>

    人工智能多模态与视觉大模型开发实战 - 2026必会

    训练模型可以逐渐提升对图像的理解能力,实现对各种视觉任务的精准处理。 此外,视觉大模型的发展还得益于大规模数据集和强大计算资源的支持。海量标注数据为
    发表于 04-15 16:06

    九天菜菜大模型agent智能体开发实战2026一月班

    提供更精准的理财建议。在医疗领域,Agent 可以辅助医生进行疾病诊断,快速分析大量医学文献和病例数据,为治疗方案提供参考,提高诊断效率和准确性。 此次大模型 Agent 开发实战课程的火爆开课,正是
    发表于 04-15 16:04

    AI落地培训 | 人形检测模型嵌入式部署全链路实战

    你是否想系统了解AI落地全链路,却缺少一个完整的实战项目练手?模型部署环节繁多,缺乏一套清晰的实战路径?4月18日、4月25日、5月16日RT-Thread将分别在苏州、成都、南京举办“AI
    的头像 发表于 04-10 18:41 142次阅读
    AI落地培训 | 人形检测<b class='flag-5'>模型</b>嵌入式部署全链路<b class='flag-5'>实战</b>

    AI模型训练与部署实战 | 线下免费培训

    你是否想系统了解AI落地全链路,但缺少一个完整的实战项目练手?模型部署环节繁多,缺乏一套清晰的实战路径?4月18日、4月25日、5月16日RT-Thread将分别在苏州、成都、南京举办“AI
    的头像 发表于 04-07 13:08 608次阅读
    AI<b class='flag-5'>模型</b><b class='flag-5'>训练</b>与部署<b class='flag-5'>实战</b> | 线下免费培训

    【2025夏季班正课】大模型Agent智能体开发实战 课分享

    【2025年12月班】大模型与Agent智能体开发实战] 拒绝碎片化:体系化学 Agent 开发方法的技术深度剖析 在当今的人工智能应用开发领域,一种浮躁的“碎片化”风气正在蔓延。许多
    发表于 03-29 16:12

    如何训练自己的AI模型——RT-Thread×富瀚微FH8626V300L模型训练部署教程 | 技术集结

    面对消费电子中纷繁的智能检测需求,如何让算法持续进化?富瀚微最新发布的FH86X6V300芯片AI训练教程,以FH8626V300L为硬件核心,手把手带您走通从模型训练到端侧部署的完整
    的头像 发表于 02-09 11:51 596次阅读
    如何<b class='flag-5'>训练</b>自己的AI<b class='flag-5'>模型</b>——RT-Thread×富瀚微FH8626V300L<b class='flag-5'>模型</b><b class='flag-5'>训练</b>部署教程 | 技术集结

    五大卫星运管中心大模型智能决策分系统软件的应用与未来发展

        五大机构/企业卫星运管中心大模型智能决策分系统实践综述    当前,随着大规模星座部署与智能化作战需求激增,以大模型驱动的卫星
    的头像 发表于 12-18 14:58 488次阅读

    模型赋能物资需求精准预测与采购系统软件平台

        北京五木恒润大模型赋能物资需求精准预测与采购平台系统软件,深度融合多源数据与智能算法,大幅提升需求预测准确性与采购决策科学性,成为企业优化供应链管理、降低运营成本的核心工具。以
    的头像 发表于 12-17 16:37 372次阅读

    在Ubuntu20.04系统中训练神经网络模型的一些经验

    本帖欲分享在Ubuntu20.04系统中训练神经网络模型的一些经验。我们采用jupyter notebook作为开发IDE,以TensorFlow2为训练框架,目标是训练一个手写数字识
    发表于 10-22 07:03

    无人驾驶:智能决策精准执行的融合

    无人驾驶核心操控技术:智能决策精准执行的融合 无人驾驶的核心操控系统是车辆实现自主驾驶的“大脑”与“四肢”,其技术核心在于通过感知、决策、执行三大模块的协同工作,替代人类驾驶员完成实
    的头像 发表于 09-19 14:03 878次阅读

    不仅管设备,还能管数据!智能系统让运维决策更精准

    智能系统在设备管理领域的应用,为企业带来了全方位的价值提升。它不仅实现了对设备的高效管理,更通过强大的数据管理能力,为运维决策提供了精准依据,帮助企业降低成本、提高生产效率、增强市场竞争力。
    的头像 发表于 09-05 10:10 873次阅读
    不仅管设备,还能管数据!<b class='flag-5'>智能</b>系统让运维<b class='flag-5'>决策</b><b class='flag-5'>更精准</b>

    动态感知+智能决策,一文解读 AI 场景组网下的动态智能选路技术

    人工智能(AI),特别是大规模模型训练和推理,正以前所未有的方式重塑数据中心网络。传统的“尽力而为”网络架构,在处理海量、突发的AI数据洪流时捉襟见肘。AI模型对网络性能的严苛要求——
    的头像 发表于 06-20 15:01 1698次阅读
    动态感知+<b class='flag-5'>智能</b><b class='flag-5'>决策</b>,一文解读 AI 场景组网下的动态<b class='flag-5'>智能</b>选路技术

    宇视科技梧桐大模型赋能交通治理

    迭代已实现多场景任务覆盖。近期推出的“梧桐”大模型事件检测和交通抓拍系列产品,将AI能力深度落地于交通治理场景,在算法精度、成像质量和分析性能等方面实现显著提升,赋能交通治理更精准、更智能、更高效。
    的头像 发表于 05-16 17:23 1238次阅读