在 AI 领域,文本翻译、语音识别、股价预测等场景都离不开序列数据处理。循环神经网络(RNN)作为最早的序列建模工具,开创了 “记忆历史信息” 的先河;而长短期记忆网络(LSTM)则通过创新设计,突破了 RNN 的核心局限。今天,我们从原理、梯度推导到实践,全面解析这两大经典模型。
一、基础铺垫:RNN 的核心逻辑与痛点
RNN 的核心是让模型 “记住过去”—— 通过隐藏层的循环连接,将前一时刻的信息传递到当前时刻,从而捕捉序列的时序关联。但这种 “全记忆” 设计,也埋下了梯度消失的隐患。
1.1 核心结构与参数

RNN 结构简化为 “输入层 - 隐藏层 - 输出层”,关键组件如下:
- 输入:Xt(第 t 时刻输入,如文本中的词向量)
- 隐藏状态:St(存储截至 t 时刻的历史信息,核心记忆载体)
- 输出:Ot(第 t 时刻预测结果,如分类标签)
- 共享参数(所有时间步复用):Wx:输入→隐藏层权重矩阵(维度:隐藏层维度 × 输入维度)Ws:隐藏层→自身的循环权重矩阵(维度:隐藏层维度 × 隐藏层维度,关键)Wo:隐藏层→输出层权重矩阵(维度:输出维度 × 隐藏层维度)偏置:b₁(隐藏层偏置,维度:隐藏层维度 ×1)、b₂(输出层偏置,维度:输出维度 ×1)
- 激活函数:隐藏层用 tanh(值缩至 [-1,1]),输出层用 Softmax(分类)或线性激活(回归)
1.2 前向传播:信息如何流动?
前向传播是 “输入→输出” 的计算过程,每个时间步的结果依赖前一时刻的隐藏状态(以下基于标量简化,向量场景逻辑一致):更新隐藏状态

当前记忆 St由 “当前输入 Xt” 和 “历史记忆 St-1” 共同决定,tanh 确保状态值在合理范围。计算输出

输出仅依赖当前记忆St,体现 “历史信息已压缩到St中”。示例:若序列长度为 3(t=1,2,3),初始状态 S₀=0(无历史信息):

1.3 反向传播(BPTT)与梯度推导
模型训练依赖时间反向传播(BPTT):通过链式法则回溯所有时间步,计算损失对参数的梯度,再用梯度下降更新参数。假设损失函数为交叉熵损失 Loss = L (Ot, yt)(yt为 T 时刻真实标签),核心是推导 Loss 对 Wx、Ws、Wo的梯度。1.3.1 核心梯度推导步骤步骤 1:计算 Loss 对输出Ot的梯度若输出层用 Softmax 激活 + 交叉熵损失,对单个样本有:

当i=j时,等于:

当i≠j时,等于:

所以,softmax函数的导数可以表示为:

我们只需要将softmax层的输出pi,pj代入上面的公式就可以做求导计算了。在多分类任务中,我们通常使用交叉熵损失函数(cross-entropy loss function)来评估模型的性能。交叉熵损失函数的定义如下:

其中yj是真实标签的one−ℎot向量,pj是softmax函数的输出。交叉熵损失函数的作用是衡量模型的预测概率p和真实标签y之间的差异。交叉熵损失越小,表示模型的预测值越接近真实的标签。经验告诉我们,当使用softmax函数作为输出层激活函数时,最好使用交叉熵作为其损失函数,这是因为交叉熵和softmax函数的结合可以简化反向传播的计算。为了证明这一点,我们对交叉熵函数求导:

其中∂pj/∂zi就是上文推导的softmax的导数,将其代入式中可得:

所以y是one-hot向量,所以:

最后,化简得到的交叉熵函数的求导公式:

步骤 2:计算 Loss 对隐藏状态 Sₜ的梯度隐藏状态St同时影响当前输出Ot和下一时刻隐藏状态 St+1,因此梯度需分两部分:

拆解导数项:

由St+1=tanh(WxXt+1+WsSt+b1)求导:tanh'(x)=1−tanh2(x)因此递推公式为:

(向量场景需转置)步骤 3:计算 Loss 对参数的梯度对 Wo的梯度:

(向量场景下为外积)对 Wx的梯度:Wₓ在所有时间步共享,需累加各时间步贡献:

对Ws的梯度:同理,Ws的梯度为各时间步贡献的累加:

1.3.2 梯度消失的核心原因:累乘衰减
从Ws的梯度公式可见,远时刻(如 t=1)对梯度的贡献需经过多次 tanh'(Sk)・Ws的累乘(k 从 2 到 T):tanh'(Sk) ∈ [0,1](tanh 导数特性,最大值为 1,多数时刻小于 0.5)|Ws| < 1(为避免数值爆炸,初始化时会限制权重范围)导致累乘项随时间步指数级衰减,例如:若tanh'(Sk)=0.5,|Ws|=0.8,序列长度T=10,则累乘项 =(0.5×0.8)^9≈0.00026,远时刻梯度趋近于 0,模型无法捕捉长期依赖。
突破局限:LSTM 的创新设计与梯度推导
1997 年提出的 LSTM,通过“记忆细胞 + 门控机制”实现 “选择性记忆”—— 保留重要信息、过滤噪声,从根本上缓解梯度消失。
2.1 核心结构:三门 + 记忆细胞

LSTM 的核心是 “记忆细胞(Cₜ)” 和三个门控,分工明确(以下基于标量简化):
组件 | 功能 | 激活函数 | 参数(权重 + 偏置) |
记忆细胞 Ct | 长期记忆载体,状态平缓更新 | 无 | 依赖门控参数 |
遗忘门 ft | 控制保留多少历史细胞状态 Ct-1 | σ(Sigmoid,输出[0,1]) | Wxf、Wℎf、bf |
更新门 it | 控制加入多少新信息到 Ct | σ(输出 [0,1]) | Wxi、Wℎi、bi |
候选记忆 gt | 生成当前时刻的新候选信息 | tanh(输出[-1,1]) | Wxg、Wℎg、bg |
输出门 ot | 控制 Ct输出到隐藏状态 ht的比例 | σ(输出 [0,1]) | Wxo、Wℎo、bo |
隐藏状态 ht | 短期记忆,用于当前输出 | tanh(输出[-1,1]) | Wyo、bo |
⊗ | 元素相乘 | 无 | 无 |
⊕ | 元素相加 | 无 | 无 |
σ 函数输出 [0,1],完美适配 “门控控制”(1 = 完全保留,0 = 完全过滤);tanh 确保信息值在合理范围
2.2 前向传播:5 步完成记忆更新
LSTM 的前向传播围绕 “记忆细胞更新” 展开,步骤清晰:遗忘门:决定 “丢什么”ft=σ(Wxf⋅Xt+Wℎf⋅ℎt−1+bf)例:ft=0.9→保留 90% 历史记忆Ct−1;ft=0.1→过滤 90% 旧信息。更新门 + 候选记忆:决定 “加什么”更新门:it=σ(Wxi⋅Xt+Wℎi⋅ht−1+bi)(控制新信息的权重)候选记忆:gt=tanh(Wxg⋅Xt+Wℎg⋅ht−1+bg)(当前时刻的新信息)更新记忆细胞:“丢旧 + 加新”Ct=Ct−1⊗ft+gt⊗it⊗为对应元素相乘,Ct同时承载 “长期历史Ct−1⊗ft” 和 “当前新信息gt⊗it”。输出门:决定 “输出什么”ot=σ(Wxo⋅Xt+Wℎo⋅ℎt−1+bo)生成隐藏状态与输出ht=ot⊗tanh (Ct)(tanh 将Ct缩至 [-1,1],再通过ot过滤)yt=Wyℎ⋅ℎt+by(最终预测结果,分类任务需加 Softmax)
2.3 反向传播与梯度推导
LSTM 的反向传播仍基于 BPTT,但需同时更新三门参数和记忆细胞相关梯度,核心是确保记忆细胞 Cₜ的梯度稳定传递。假设损失 Loss = L (yt,y't)(y't为真实标签),以下为关键梯度推导。
2.3.1 核心梯度 1:Loss 对记忆细胞 Cₜ的导数
记忆细胞Ct同时影响当前隐藏状态ℎt和下一时刻记忆细胞Ct+1,梯度公式为:

拆解导数项:∂Loss/∂ℎt:损失对隐藏状态的梯度,由输出层反向推导:

(包含当前输出和下一时刻四门的贡献)∂ℎt/∂Ct=ot⋅tanh'(Ct)(由ℎt=ot⊗tanh (Ct)求导)∂Ct+1/∂Ct=ft+1(由Ct+1=Ct⊗ft+1+gt+1⊗it+1求导)最终递推公式:

2.3.2 核心梯度 2:Loss 对门控参数的导数(以遗忘门为例)遗忘门参数(Wxf、Whf、bf)的梯度需通过链式法则推导:先求 Loss 对遗忘门输出ft的梯度:

再求 Loss 对遗忘门权重 Wxf的梯度:

(σ函数导数为σ(x)・(1-σ(x)),此处 ft=σ(...),故∂ft/∂Wxf⋅ft⋅(1−ft)⋅Xt)同理,Loss对Whf的梯度:

更新门、输出门、候选记忆的参数梯度推导逻辑一致,最终所有参数通过梯度下降(如 Adam 优化器)更新。2.3.3 LSTM 如何缓解梯度消失?对比 RNN 的梯度路径,LSTM 的记忆细胞梯度传递具有决定性优势:从∂Loss∂Ct的递推公式可见,当模型需要保留长期信息时,会通过参数学习使遗忘门ft+1≈1,此时:

由于 tanh'(Ct)∈[0,1],ot∈[0,1],但核心是∂Loss/∂Ct+1直接传递到∂Loss/∂Ct,无指数级衰减。即使序列长度达到 100+,远时刻(如 t=1)的梯度仍能稳定传递到当前时刻(如 t=100),从而有效捕捉长期依赖。
关键补充:模型如何 “学习” 让ft+1≈1?
遗忘门ft+1的输出由以下公式决定:ft+1=σ(Wxf⋅Xt+1+Wℎf⋅ℎt+bf)
其中σ是 Sigmoid 函数,当输入值>2 时,σ(x)≈0.95(接近1)。模型通过以下两种方式学习让ft+1≈1:
初始化阶段:设置遗忘门偏置 bf>0
工程实践中,会将遗忘门的偏置bf初始化为1~2(而非默认0),此时即使Wxf⋅Xt+1+Wℎf⋅ℎt=0,ft+1=σ(bf)≈0.73(已较高),为后续学习 “保留长期信息” 奠定基础。训练阶段:通过损失反向调整参数当模型因 “未保留远时刻信息” 导致 Loss 升高时,反向传播会调整Wxf、Wℎf、bf的取值:若 t=1 的信息对 t=100 的预测很重要,但当前f2=0.1(过滤了 t=1 的信息),则 Loss 会增大;反向传播时,∂Loss/∂f2为正值(增加f2可降低 Loss),进而通过∂Loss/∂Wf调整权重,使f2增大;反复迭代后,模型会学习到 “对重要的长期信息,让ft+1≈1。
三、RNN vs LSTM:怎么选?
两大模型各有优劣,需结合场景匹配:
维度 | 循环神经网络(RNN) | 长短期记忆网络(LSTM) |
记忆能力 | 仅短期依赖 | 长期依赖(序列长度 100+) |
梯度问题 | 易出现梯度消失,远时刻信息丢失 | 记忆细胞梯度稳定,缓解梯度消失 |
模型复杂度 | 低(仅 3 组核心参数:Wₓ、Wₛ、Wₒ) | 高(9 组核心参数:3 门 ×3 组权重 + 输出层权重) |
参数数量 | 少(如隐藏层维度 H=128,输入维度 D=64,参数量≈128²+128×64=24576) | 多(同上述维度,参数量≈4×(128²+128×64)=98304,约为 RNN 的 4 倍) |
计算效率 | 快(前向 / 反向传播步骤少) | 慢(门控计算多) |
训练难度 | 低(参数少,收敛快,易实现) | 高(参数多,易过拟合,需更多数据和正则化) |
核心优势 | 结构简单、训练速度快、资源占用低 | 鲁棒性强、长期依赖捕捉能力突出、任务精度高 |
四、工程实践小贴士4.1 模型选择策略先简后繁:先用 RNN 验证短序列任务可行性,若精度不达标(如测试集准确率 < 85%),再替换为 LSTM;折中方案:若 LSTM 计算压力大,可选用 GRU(门控循环单元)—— 简化为重置门和更新门 2 个门,参数量比 LSTM 少 25%,性能接近 LSTM;
数据适配:若序列长度差异大(如文本长度 5-200 词),可采用 “截断 + 填充”(固定序列长度)或 “动态批处理”(同批次序列长度一致)。4.2 LSTM 性能优化技巧参数裁剪:隐藏层维度从 256 降至 128,参数量减少 75%,训练速度提升 2-3 倍;序列分段:将长序列(如 1000 帧音频)拆分为 10 个 100 帧子序列,采用 “滚动预测” 拼接结果;量化训练:将 32 位浮点数参数转为 16 位半精度,显存占用减少 50%,推理速度提升 1.5 倍;正则化:添加 Dropout(隐藏层 dropout 率 0.2-0.5)、L2 正则化(权重衰减系数 1e-4),缓解过拟合。
4.3 常见问题排查
问题现象 | 可能原因 | 解决方案 |
训练 loss 不下降 | 1. 学习率过高 / 过低2. 梯度消失(LSTM 遗忘门ft过小) | 1. 调整学习率(如 Adam优化器默认 0.001,可尝试0.0001-0.01)2. 初始化遗忘门偏置bf为1-2(使ft初始值接近 1) |
测试集 loss 波动大 | 1. 数据量不足2. 序列长度分布不均 | 1. 数据增强(如文本同义词替换、时序数据加噪)2. 按序列长度分组训练,平衡各长度样本占比 |
总结RNN 作为序列建模的 “基石”,以简单的循环结构开创了历史信息复用的思路,但受限于梯度消失无法处理长序列;LSTM 则通过记忆细胞和门控机制的创新,从梯度传递路径上解决了长期依赖问题,成为长序列任务的经典方案。尽管当前 Transformer(如 BERT、GPT)在多数序列任务中表现更优,但 RNN 和 LSTM 的核心思想(时序关联捕捉、选择性记忆)仍是理解复杂序列模型的基础,也是 AI 工程师在资源受限场景下的重要选择。你在项目中用过 RNN 或 LSTM 吗?遇到过哪些训练难题?欢迎在评论区分享你的实践经验!
本文转自:秦芯智算
-
神经网络
+关注
关注
42文章
4829浏览量
106807 -
rnn
+关注
关注
0文章
92浏览量
7301 -
LSTM
+关注
关注
0文章
63浏览量
4296
发布评论请先 登录

一文读懂LSTM与RNN:从原理到实战,掌握序列建模核心技术
评论