0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

一文读懂LSTM与RNN:从原理到实战,掌握序列建模核心技术

颖脉Imgtec 2025-12-09 13:56 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

AI 领域,文本翻译、语音识别、股价预测等场景都离不开序列数据处理。循环神经网络(RNN)作为最早的序列建模工具,开创了 “记忆历史信息” 的先河;而长短期记忆网络(LSTM)则通过创新设计,突破了 RNN 的核心局限。今天,我们从原理、梯度推导到实践,全面解析这两大经典模型。


一、基础铺垫:RNN 的核心逻辑与痛点

RNN 的核心是让模型 “记住过去”—— 通过隐藏层的循环连接,将前一时刻的信息传递到当前时刻,从而捕捉序列的时序关联。但这种 “全记忆” 设计,也埋下了梯度消失的隐患。

1.1 核心结构与参数

ca571890-d4c3-11f0-8ce9-92fbcf53809c.jpgca7a1020-d4c3-11f0-8ce9-92fbcf53809c.png
RNN 结构简化为 “输入层 - 隐藏层 - 输出层”,关键组件如下:

  • 输入:Xt(第 t 时刻输入,如文本中的词向量)
  • 隐藏状态:St(存储截至 t 时刻的历史信息,核心记忆载体)
  • 输出:Ot(第 t 时刻预测结果,如分类标签
  • 共享参数(所有时间步复用):Wx:输入→隐藏层权重矩阵(维度:隐藏层维度 × 输入维度)Ws:隐藏层→自身的循环权重矩阵(维度:隐藏层维度 × 隐藏层维度,关键)Wo:隐藏层→输出层权重矩阵(维度:输出维度 × 隐藏层维度)偏置:b₁(隐藏层偏置,维度:隐藏层维度 ×1)、b₂(输出层偏置,维度:输出维度 ×1)
  • 激活函数:隐藏层用 tanh(值缩至 [-1,1]),输出层用 Softmax(分类)或线性激活(回归)

1.2 前向传播:信息如何流动?

前向传播是 “输入→输出” 的计算过程,每个时间步的结果依赖前一时刻的隐藏状态(以下基于标量简化,向量场景逻辑一致):更新隐藏状态

ca993266-d4c3-11f0-8ce9-92fbcf53809c.png

当前记忆 St由 “当前输入 Xt” 和 “历史记忆 St-1” 共同决定,tanh 确保状态值在合理范围。计算输出

cab17e02-d4c3-11f0-8ce9-92fbcf53809c.png

输出仅依赖当前记忆St,体现 “历史信息已压缩到St中”。示例若序列长度为 3(t=1,2,3),初始状态 S₀=0(无历史信息):

cac8cd0a-d4c3-11f0-8ce9-92fbcf53809c.png

1.3 反向传播(BPTT)与梯度推导

模型训练依赖时间反向传播(BPTT):通过链式法则回溯所有时间步,计算损失对参数的梯度,再用梯度下降更新参数。假设损失函数为交叉熵损失 Loss = L (Ot, yt)(yt为 T 时刻真实标签),核心是推导 Loss 对 Wx、Ws、Wo的梯度。1.3.1 核心梯度推导步骤步骤 1:计算 Loss 对输出Ot的梯度若输出层用 Softmax 激活 + 交叉熵损失,对单个样本有:

cae46f38-d4c3-11f0-8ce9-92fbcf53809c.png

当i=j时,等于:

cafe18de-d4c3-11f0-8ce9-92fbcf53809c.png

当i≠j时,等于:

cb167ba4-d4c3-11f0-8ce9-92fbcf53809c.png

所以,softmax函数的导数可以表示为:

cb2ca046-d4c3-11f0-8ce9-92fbcf53809c.png

我们只需要将softmax层的输出pi,pj代入上面的公式就可以做求导计算了。在多分类任务中,我们通常使用交叉熵损失函数(cross-entropy loss function)来评估模型的性能。交叉熵损失函数的定义如下:

cb4cec48-d4c3-11f0-8ce9-92fbcf53809c.png

其中yj是真实标签的one−ℎot向量,pj是softmax函数的输出。交叉熵损失函数的作用是衡量模型的预测概率p和真实标签y之间的差异。交叉熵损失越小,表示模型的预测值越接近真实的标签。经验告诉我们,当使用softmax函数作为输出层激活函数时,最好使用交叉熵作为其损失函数,这是因为交叉熵和softmax函数的结合可以简化反向传播的计算。为了证明这一点,我们对交叉熵函数求导:

cb73aba8-d4c3-11f0-8ce9-92fbcf53809c.png

其中∂pj/∂zi就是上文推导的softmax的导数,将其代入式中可得:

cb8c4f00-d4c3-11f0-8ce9-92fbcf53809c.png

所以y是one-hot向量,所以:

cba7c488-d4c3-11f0-8ce9-92fbcf53809c.png

最后,化简得到的交叉熵函数的求导公式:

cbca6772-d4c3-11f0-8ce9-92fbcf53809c.png

步骤 2:计算 Loss 对隐藏状态 Sₜ的梯度隐藏状态St同时影响当前输出Ot和下一时刻隐藏状态 St+1,因此梯度需分两部分:

cc343800-d4c3-11f0-8ce9-92fbcf53809c.png

拆解导数项:

cc504914-d4c3-11f0-8ce9-92fbcf53809c.png

由St+1=tanh(WxXt+1+WsSt+b1)求导:tanh'(x)=1−tanh2(x)因此递推公式为:

cc6ebe62-d4c3-11f0-8ce9-92fbcf53809c.png

(向量场景需转置)步骤 3:计算 Loss 对参数的梯度对 Wo的梯度:

cc8ddc98-d4c3-11f0-8ce9-92fbcf53809c.png

(向量场景下为外积)对 Wx的梯度:Wₓ在所有时间步共享,需累加各时间步贡献:

ccaa8b18-d4c3-11f0-8ce9-92fbcf53809c.png

对Ws的梯度:同理,Ws的梯度为各时间步贡献的累加:

ccc1d9d0-d4c3-11f0-8ce9-92fbcf53809c.png

1.3.2 梯度消失的核心原因:累乘衰减

从Ws的梯度公式可见,远时刻(如 t=1)对梯度的贡献需经过多次 tanh'(Sk)・Ws的累乘(k 从 2 到 T):tanh'(Sk) ∈ [0,1](tanh 导数特性,最大值为 1,多数时刻小于 0.5)|Ws| < 1(为避免数值爆炸,初始化时会限制权重范围)导致累乘项随时间步指数级衰减,例如:若tanh'(Sk)=0.5,|Ws|=0.8,序列长度T=10,则累乘项 =(0.5×0.8)^9≈0.00026,远时刻梯度趋近于 0,模型无法捕捉长期依赖


突破局限:LSTM 的创新设计与梯度推导

1997 年提出的 LSTM,通过“记忆细胞 + 门控机制”实现 “选择性记忆”—— 保留重要信息、过滤噪声,从根本上缓解梯度消失。

2.1 核心结构:三门 + 记忆细胞

ccd7eb12-d4c3-11f0-8ce9-92fbcf53809c.png

LSTM 的核心是 “记忆细胞(Cₜ)” 和三个门控,分工明确(以下基于标量简化):

组件

功能

激活函数

参数(权重 + 偏置)

记忆细胞 Ct

长期记忆载体,状态平缓更新

依赖门控参数

遗忘门 ft

控制保留多少历史细胞状态 Ct-1

σ(Sigmoid,输出[0,1])

Wxf、Wℎf、bf

更新门 it

控制加入多少新信息到 Ct

σ(输出 [0,1])

Wxi、Wℎi、bi

候选记忆 gt

生成当前时刻的新候选信息

tanh(输出[-1,1])

Wxg、Wℎg、bg

输出门 ot

控制 Ct输出到隐藏状态 ht的比例

σ(输出 [0,1])

Wxo、Wℎo、bo

隐藏状态 ht

短期记忆,用于当前输出

tanh(输出[-1,1])

Wyo、bo

元素相乘

元素相加

σ 函数输出 [0,1],完美适配 “门控控制”(1 = 完全保留,0 = 完全过滤);tanh 确保信息值在合理范围

2.2 前向传播:5 步完成记忆更新

LSTM 的前向传播围绕 “记忆细胞更新” 展开,步骤清晰:遗忘门:决定 “丢什么”ft=σ(Wxf⋅Xt+Wℎf⋅ℎt−1+bf)例:ft=0.9→保留 90% 历史记忆Ct−1;ft=0.1→过滤 90% 旧信息。更新门 + 候选记忆:决定 “加什么”更新门:it=σ(Wxi⋅Xt+Wℎi⋅ht−1+bi)(控制新信息的权重)候选记忆:gt=tanh(Wxg⋅Xt+Wℎg⋅ht−1+bg)(当前时刻的新信息)更新记忆细胞:“丢旧 + 加新”Ct=Ct−1⊗ft+gt⊗it⊗为对应元素相乘,Ct同时承载 “长期历史Ct−1⊗ft” 和 “当前新信息gt⊗it”。输出门:决定 “输出什么”ot=σ(Wxo⋅Xt+Wℎo⋅ℎt−1+bo)生成隐藏状态与输出ht=ot⊗tanh (Ct)(tanh 将Ct缩至 [-1,1],再通过ot过滤)yt=Wyℎ⋅ℎt+by(最终预测结果,分类任务需加 Softmax)

2.3 反向传播与梯度推导

LSTM 的反向传播仍基于 BPTT,但需同时更新三门参数和记忆细胞相关梯度,核心是确保记忆细胞 Cₜ的梯度稳定传递。假设损失 Loss = L (yt,y't)(y't为真实标签),以下为关键梯度推导。

2.3.1 核心梯度 1:Loss 对记忆细胞 Cₜ的导数

记忆细胞Ct同时影响当前隐藏状态ℎt和下一时刻记忆细胞Ct+1,梯度公式为:

ccf0daf0-d4c3-11f0-8ce9-92fbcf53809c.png

拆解导数项:∂Loss/∂ℎt:损失对隐藏状态的梯度,由输出层反向推导:

cd08145e-d4c3-11f0-8ce9-92fbcf53809c.png

(包含当前输出和下一时刻四门的贡献)∂ℎt/∂Ct=ot⋅tanh'(Ct)(由ℎt=ot⊗tanh (Ct)求导)∂Ct+1/∂Ct=ft+1(由Ct+1=Ct⊗ft+1+gt+1⊗it+1求导)最终递推公式:

cd235aa2-d4c3-11f0-8ce9-92fbcf53809c.png

2.3.2 核心梯度 2:Loss 对门控参数的导数(以遗忘门为例)遗忘门参数(Wxf、Whf、bf)的梯度需通过链式法则推导:先求 Loss 对遗忘门输出ft的梯度:

cd367d62-d4c3-11f0-8ce9-92fbcf53809c.png

再求 Loss 对遗忘门权重 Wxf的梯度:

cd50e896-d4c3-11f0-8ce9-92fbcf53809c.png

(σ函数导数为σ(x)・(1-σ(x)),此处 ft=σ(...),故∂ft/∂Wxf⋅ft⋅(1−ft)⋅Xt)同理,Loss对Whf的梯度:

cd6e2078-d4c3-11f0-8ce9-92fbcf53809c.png

更新门、输出门、候选记忆的参数梯度推导逻辑一致,最终所有参数通过梯度下降(如 Adam 优化器)更新。2.3.3 LSTM 如何缓解梯度消失?对比 RNN 的梯度路径,LSTM 的记忆细胞梯度传递具有决定性优势:从∂Loss∂Ct的递推公式可见,当模型需要保留长期信息时,会通过参数学习使遗忘门ft+1≈1,此时:

cd886212-d4c3-11f0-8ce9-92fbcf53809c.png

由于 tanh'(Ct)∈[0,1],ot∈[0,1],但核心是∂Loss/∂Ct+1直接传递到∂Loss/∂Ct,无指数级衰减。即使序列长度达到 100+,远时刻(如 t=1)的梯度仍能稳定传递到当前时刻(如 t=100),从而有效捕捉长期依赖。

关键补充:模型如何 “学习” 让ft+1≈1?

遗忘门ft+1的输出由以下公式决定:ft+1=σ(Wxf⋅Xt+1+Wℎf⋅ℎt+bf)
其中σ是 Sigmoid 函数,当输入值>2 时,σ(x)≈0.95(接近1)。模型通过以下两种方式学习让ft+1≈1:

初始化阶段:设置遗忘门偏置 bf>0

工程实践中,会将遗忘门的偏置bf初始化为1~2(而非默认0),此时即使Wxf⋅Xt+1+Wℎf⋅ℎt=0,ft+1=σ(bf)≈0.73(已较高),为后续学习 “保留长期信息” 奠定基础。训练阶段:通过损失反向调整参数当模型因 “未保留远时刻信息” 导致 Loss 升高时,反向传播会调整Wxf、Wℎf、bf的取值:若 t=1 的信息对 t=100 的预测很重要,但当前f2=0.1(过滤了 t=1 的信息),则 Loss 会增大;反向传播时,∂Loss/∂f2为正值(增加f2可降低 Loss),进而通过∂Loss/∂Wf调整权重,使f2增大;反复迭代后,模型会学习到 “对重要的长期信息,让ft+1≈1。


三、RNN vs LSTM:怎么选?

两大模型各有优劣,需结合场景匹配:

维度

循环神经网络(RNN)

长短期记忆网络(LSTM)

记忆能力

仅短期依赖

长期依赖(序列长度 100+)

梯度问题

易出现梯度消失,远时刻信息丢失

记忆细胞梯度稳定,缓解梯度消失

模型复杂度

低(仅 3 组核心参数:Wₓ、Wₛ、Wₒ)

高(9 组核心参数:3 门 ×3 组权重 + 输出层权重)

参数数量

少(如隐藏层维度 H=128,输入维度 D=64,参数量≈128²+128×64=24576)

多(同上述维度,参数量≈4×(128²+128×64)=98304,约为 RNN 的 4 倍)

计算效率

快(前向 / 反向传播步骤少)

慢(门控计算多)

训练难度

低(参数少,收敛快,易实现)

高(参数多,易过拟合,需更多数据和正则化)

核心优势

结构简单、训练速度快、资源占用低

鲁棒性强、长期依赖捕捉能力突出、任务精度高



四、工程实践小贴士4.1 模型选择策略先简后繁:先用 RNN 验证短序列任务可行性,若精度不达标(如测试集准确率 < 85%),再替换为 LSTM;折中方案:若 LSTM 计算压力大,可选用 GRU(门控循环单元)—— 简化为重置门和更新门 2 个门,参数量比 LSTM 少 25%,性能接近 LSTM;

数据适配:若序列长度差异大(如文本长度 5-200 词),可采用 “截断 + 填充”(固定序列长度)或 “动态批处理”(同批次序列长度一致)。4.2 LSTM 性能优化技巧参数裁剪:隐藏层维度从 256 降至 128,参数量减少 75%,训练速度提升 2-3 倍;序列分段:将长序列(如 1000 帧音频)拆分为 10 个 100 帧子序列,采用 “滚动预测” 拼接结果;量化训练:将 32 位浮点数参数转为 16 位半精度,显存占用减少 50%,推理速度提升 1.5 倍;正则化:添加 Dropout(隐藏层 dropout 率 0.2-0.5)、L2 正则化(权重衰减系数 1e-4),缓解过拟合。

4.3 常见问题排查

问题现象

可能原因

解决方案

训练 loss 不下降

1. 学习率过高 / 过低2. 梯度消失(LSTM 遗忘门ft过小)

1. 调整学习率(如 Adam优化器默认 0.001,可尝试0.0001-0.01)2. 初始化遗忘门偏置bf为1-2(使ft初始值接近 1)

测试集 loss 波动大

1. 数据量不足2. 序列长度分布不均

1. 数据增强(如文本同义词替换、时序数据加噪)2. 按序列长度分组训练,平衡各长度样本占比

总结RNN 作为序列建模的 “基石”,以简单的循环结构开创了历史信息复用的思路,但受限于梯度消失无法处理长序列;LSTM 则通过记忆细胞和门控机制的创新,从梯度传递路径上解决了长期依赖问题,成为长序列任务的经典方案。尽管当前 Transformer(如 BERT、GPT)在多数序列任务中表现更优,但 RNN 和 LSTM 的核心思想(时序关联捕捉、选择性记忆)仍是理解复杂序列模型的基础,也是 AI 工程师在资源受限场景下的重要选择。你在项目中用过 RNN 或 LSTM 吗?遇到过哪些训练难题?欢迎在评论区分享你的实践经验!

本文转自:秦芯智算

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4829

    浏览量

    106807
  • rnn
    rnn
    +关注

    关注

    0

    文章

    92

    浏览量

    7301
  • LSTM
    +关注

    关注

    0

    文章

    63

    浏览量

    4296
收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    蓝牙核心技术概述

    蓝牙核心技术概述():蓝牙概述蓝牙核心技术概述(二):蓝牙使用场景蓝牙核心技术概述(三): 蓝牙协议规范(射频、基带链路控制、链路管理)蓝牙核心技
    发表于 11-24 16:06

    FPGA也能做RNN

    进行了测试。该实现比嵌入在Zynq 7020 FPGA上的ARM Cortex-A9 CPU快了21倍。LSTM种特殊的RNN,由于独特的设计结构,LSTM适合于处理和预测时间
    发表于 07-31 10:11

    放弃 RNNLSTM 吧,它们真的不好用

    2014 年 RNN/LSTM 起死回生。自此,RNN/LSTM 及其变种逐渐被广大用户接受和认可。起初,LSTM
    的头像 发表于 04-25 09:43 2.1w次阅读

    循环神经网络(RNN)和(LSTM)初学者指南

    最近,有篇入门文章引发了不少关注。文章中详细介绍了循环神经网络(RNN),及其变体长短期记忆(LSTM)背后的原理。
    发表于 02-05 13:43 1297次阅读

    中国尚未掌握核心技术清单

    什么时候中国能把稀土玩到美日层次,什么时候就掌握了目前50%未掌握核心技术
    的头像 发表于 05-22 14:31 7093次阅读

    RNN以及LSTM

    循环神经网络(Recurrent Neural Network,RNN)是种用于处理序列数据的神经网络。相比般的神经网络来说,他能够处理序列
    的头像 发表于 03-15 10:44 2391次阅读

    如何理解RNNLSTM神经网络

    在深入探讨RNN(Recurrent Neural Network,循环神经网络)与LSTM(Long Short-Term Memory,长短期记忆网络)神经网络之前,我们首先需要明确它们
    的头像 发表于 07-09 11:12 1899次阅读

    LSTM神经网络的基本原理 如何实现LSTM神经网络

    LSTM(长短期记忆)神经网络是种特殊的循环神经网络(RNN),它能够学习长期依赖信息。在处理序列数据时,如时间序列分析、自然语言处理等,
    的头像 发表于 11-13 09:53 2547次阅读

    LSTM神经网络在时间序列预测中的应用

    时间序列预测是数据分析中的个重要领域,它涉及到基于历史数据预测未来值。随着深度学习技术的发展,长短期记忆(LSTM)神经网络因其在处理序列
    的头像 发表于 11-13 09:54 2689次阅读

    使用LSTM神经网络处理自然语言处理任务

    ,NLP任务的处理能力得到了显著提升。 LSTM网络简介 LSTM网络是种特殊的RNN,它通过引入门控机制来解决传统RNN在处理长
    的头像 发表于 11-13 09:56 1664次阅读

    LSTM神经网络的优缺点分析

    序列数据时的优越性能而受到广泛关注,特别是在自然语言处理(NLP)、语音识别和时间序列预测等领域。 LSTM的优点 1. 记忆能力 LSTM核心
    的头像 发表于 11-13 09:57 5737次阅读

    LSTM神经网络与传统RNN的区别

    在深度学习领域,循环神经网络(RNN)因其能够处理序列数据而受到广泛关注。然而,传统RNN在处理长序列时存在梯度消失或梯度爆炸的问题。为了解决这
    的头像 发表于 11-13 09:58 1725次阅读

    深度学习框架中的LSTM神经网络实现

    长短期记忆(LSTM)网络是种特殊的循环神经网络(RNN),能够学习长期依赖信息。与传统的RNN相比,LSTM通过引入门控机制来解决梯度消
    的头像 发表于 11-13 10:16 1539次阅读

    如何使用RNN进行时间序列预测

    种强大的替代方案,能够学习数据中的复杂模式,并进行准确的预测。 RNN的基本原理 RNN种具有循环结构的神经网络,它能够处理序列数据。
    的头像 发表于 11-15 09:45 1325次阅读

    RNNLSTM模型的比较分析

    RNN(循环神经网络)与LSTM(长短期记忆网络)模型在深度学习领域都具有处理序列数据的能力,但它们在结构、功能和应用上存在显著的差异。以下是对RNN
    的头像 发表于 11-15 10:05 2876次阅读