0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读

jf_pmFSk4VX 来源:GiantPandaCV 2024-01-14 11:19 次阅读

大家好,最近我又读了读RLHF的相关paper和一些开源实践,有了一些心得体会,整理成这篇文章。过去在RLHF的初学阶段,有一个问题最直接地困惑着我:

  • 如何在NLP语境下理解强化学习的框架?例如,我知道强化学习中有Agent、Environment、Reward、State等要素,但是在NLP语境中,它们指什么?语言模型又是如何根据奖励做更新的?

为了解答这个问题,我翻阅了很多资料,看了许多的公式推导,去研究RLHF的整体框架和loss设计。虽然吭吭哧哧地入门了,但是这个过程实在痛苦,最主要的原因是:理论的部分太多,直观的解释太少。

所以,在写这篇文章时,我直接从一个RLHF开源项目源码入手(deepspeed-chat),根据源码的实现细节,给出尽可能丰富的训练流程图,并对所有的公式给出直观的解释。希望可以帮助大家更具象地感受RLHF的训练流程。对于没有强化学习背景的朋友,也可以无痛阅读本文。关于RLHF,各家的开源代码间都会有一些差异,同时也不止PPO一种RLHF方式。感兴趣的朋友,也可以读读别家的源码,做一些对比。后续有时间,这个系列也会对各种RLHF方式进行比较。

整体内容如下:
【一、强化学习概述】
1.1 强化学习整体流程
1.2 价值函数
【二、NLP中的强化学习】
【三、RLHF中的四个重要角色】
3.1 Actor Model
3.2 Reference Model
3.3 Critic Model
3.4 Reward Model
【四、RLHF中的loss计算】
4.1 Actor loss
(1) 直观设计
(2) 引入优势
(3) 重新设计奖励函数
(4) 重新设计优势
(5)ppo_epoch: 引入新约束,提升训练效率
(6) Actor loss小结
【五、Critic loss】
(1) 实际收益优化
(2) 预估收益优化

一、强化学习概述

1. 强化学习整体流程

f6713c9a-b223-11ee-8b88-92fbcf53809c.png
  • 强化学习的两个实体:智能体(Agent)环境(Environment)
  • 强化学习中两个实体的交互:
    • 状态空间S:S即为State,指环境中所有可能状态的集合
    • 动作空间A:A即为Action,指智能体所有可能动作的集合
    • 奖励R:R即为Reward,指智能体在环境的某一状态下所获得的奖励。

以上图为例,智能体与环境的交互过程如下:

  • 在时刻,环境的状态为,达到这一状态所获得的奖励为
  • 智能体观测到与,采取相应动作
  • 智能体采取后,环境状态变为,得到相应的奖励

智能体在这个过程中学习,它的最终目标是:找到一个策略,这个策略根据当前观测到的环境状态和奖励反馈,来选择最佳的动作。

1.2 价值函数

在1.1中,我们谈到了奖励值,它表示环境进入状态下的即时奖励

但如果只考虑即时奖励,目光似乎太短浅了:当下的状态和动作会影响到未来的状态和动作,进而影响到未来的整体收益。

所以,一种更好的设计方式是:t时刻状态s的总收益 = 身处状态s能带来的即时收益 + 从状态s出发后能带来的未来收益。写成表达式就是:

其中:

  • :时刻的总收益,注意这个收益蕴涵了“即时”和“未来”的概念
  • :时刻的即时收益
  • :时刻的总收益,注意这个收益蕴涵了“即时”和“未来”的概念。而对来说就是“未来”。
  • :折扣因子。它决定了我们在多大程度上考虑将“未来收益”纳入“当下收益”。

注:在这里,我们不展开讨论RL中关于价值函数的一系列假设与推导,而是直接给出一个便于理解的简化结果,方便没有RL背景的朋友能倾注更多在“PPO策略具体怎么做”及“对PPO的直觉理解”上。

二、NLP中的强化学习

我们在第一部分介绍了通用强化学习的流程,那么我们要怎么把这个流程对应到NLP任务中呢?换句话说,NLP任务中的智能体、环境、状态、动作等等,都是指什么呢?

f68cc596-b223-11ee-8b88-92fbcf53809c.png

回想一下我们对NLP任务做强化学习(RLHF)的目的:我们希望给模型一个prompt,让模型能生成符合人类喜好的response。再回想一下gpt模型做推理的过程:每个时刻只产生一个token,即token是一个一个蹦出来的,先有上一个token,再有下一个token。

复习了这两点,现在我们可以更好解读上面这张图了:

  • 我们先喂给模型一个prompt,期望它能产出符合人类喜好的response

  • 在时刻,模型根据上文,产出一个token,这个token即对应着强化学习中的动作,我们记为。因此不难理解,在NLP语境下,强化学习任务的动作空间就对应着词表。

  • 在时刻,模型产出token 对应着的即时收益为总收益为(复习一下,蕴含着“即时收益”与“未来收益”两个内容)。这个收益即可以理解为“对人类喜好的衡量”。此刻,模型的状态从变为也就是从“上文”变成“上文 + 新产出的token”

  • 在NLP语境下,智能体是语言模型本身,环境则对应着它产出的语料

这样,我们就大致解释了NLP语境下的强化学习框架,不过针对上面这张图,你可能还有以下问题:

1)问题1:图中的下标是不是写得不太对?例如根据第一部分的介绍,应该对应着应该对应着,以此类推?

答:你说的对。但这里我们不用太纠结下标的问题,只需要记住在对应的response token位置,会产生相应的即时奖励和总收益即可。之所以用图中这样的下标,是更方便我们后续理解代码。

(2)问题2:我知道肯定是由语言模型产生的,那么是怎么来的呢,也是语言模型产生的吗?

答:先直接说结论,是由我们的语言模型产生的,则分别由另外两个模型来产生,在后文中我们会细说。

(3)问题3:语言模型的参数在什么时候更新?是观测到一个,就更新一次参数,然后再去产生吗?

答:当然不是。你只看到某个时刻的收益,就急着用它更新模型,这也太莽撞了。我们肯定是要等有足够的观测数据了(例如等模型把完整的response生成完),再去更新它的参数。这一点我们也放在后文细说。

(4)问题4:再谈谈吧,在NLP的语境下我还是不太理解它们。

答:

  • 首先,“收益”的含义是“对人类喜好的衡量”
  • :即时收益,指语言模型当下产生token 带来的收益
  • :实际期望总收益(即时+未来),指对语言模型“当下产生token ,一直到整个response生产结束”后的期收益预估。因为当下语言模型还没产出后的token,所以我们只是对它之后一系列动作的收益做了估计,因而称为“期望总收益”。

三、RLHF中的四个重要角色

在本节中,我们在第二部分的基础上更进一步:更详细理清NLP语境下RLHF的运作流程。

我们从第二部分中已经知道:生成token 和对应收益的并不是一个模型。那么在RLHF中到底有几个模型?他们是怎么配合做训练的?而我们最终要的是哪个模型?

f69ba4c6-b223-11ee-8b88-92fbcf53809c.png

如上图,在RLHF-PPO阶段,一共有四个主要模型,分别是:

  • Actor Model:演员模型,这就是我们想要训练的目标语言模型
  • Critic Model:评论家模型,它的作用是预估总收益
  • Reward Model:奖励模型,它的作用是计算即时收益
  • Reference Model:参考模型,它的作用是在RLHF阶段给语言模型增加一些“约束”,防止语言模型训歪(朝不受控制的方向更新,效果可能越来越差)

其中:

  • Actor/Critic Model在RLHF阶段是需要训练的(图中给这两个模型加了粗边,就是表示这个含义);而Reward/Reference Model参数冻结的。
  • Critic/Reward/Reference Model共同组成了一个“奖励-loss”计算体系(我自己命名的,为了方便理解),我们综合它们的结果计算loss,用于更新Actor和Critic Model

我们把这四个部分展开说说。

3.1 Actor Model (演员模型)

正如前文所说,Actor就是我们想要训练的目标语言模型。我们一般用SFT阶段产出的SFT模型来对它做初始化。

f6ac0bf4-b223-11ee-8b88-92fbcf53809c.png

我们的最终目的是让Actor模型能产生符合人类喜好的response。所以我们的策略是,先喂给Actor一条prompt(这里假设batch_size = 1,所以是1条prompt),让它生成对应的response。然后,我们再将“prompt + response"送入我们的“奖励-loss”计算体系中去算得最后的loss,用于更新actor。

3.2 Reference Model(参考模型)

Reference Model(以下简称Ref模型)一般也用SFT阶段得到的SFT模型做初始化,在训练过程中,它的参数是冻结的。Ref模型的主要作用是防止Actor”训歪”,那么它具体是怎么做到这一点的呢?

f6bcedfc-b223-11ee-8b88-92fbcf53809c.png

“防止模型训歪”换一个更详细的解释是:我们希望训练出来的Actor模型既能达到符合人类喜好的目的,又尽量让它和SFT模型不要差异太大。简言之,我们希望两个模型的输出分布尽量相似。那什么指标能用来衡量输出分布的相似度呢?我们自然而然想到了KL散度

如图所示:

  • 对Actor模型,我们喂给它一个prompt,它正常输出对应的response。那么response中每一个token肯定有它对应的log_prob结果呀,我们把这样的结果记为log_probs

  • 对Ref模型,我们把Actor生成的"prompt + response"喂给它,那么它同样能给出每个token的log_prob结果,我们记其为ref_log_probs

  • 那么这两个模型的输出分布相似度就可以用ref_log_probs - log_probs来衡量,我们可以从两个方面来理解这个公式:

    • 从直觉上理解,ref_log_probs越高,说明Ref模型对Actor模型输出的肯定性越大。即Ref模型也认为,对于某个,输出某个的概率也很高()。这时可以认为Actor模型较Ref模型没有训歪

    • 从KL散度上理解

      ,这个值越小意味着两个分布的相似性越高。而这个值越小等价于ref_log_probs - log_probs越大

注:你可能已经注意到,按照KL散度的定义,这里写成log_probs - ref_log_probs更合适一些。这里之所以写成ref_log_probs - log_probs,是为了方便大家从直觉上了解这个公式。

现在,我们已经知道怎么利用Ref模型和KL散度来防止Actor训歪了。KL散度将在后续被用于loss的计算,我们在后文中会详细解释。

3.3 Critic Model(评论家模型)

Critic Model用于预测期望总收益,和Actor模型一样,它需要做参数更新。实践中,Critic Model的设计和初始化方式也有很多种,例如和Actor共享部分参数、从RW阶段的Reward Model初始化而来等等。我们讲解时,和deepspeed-chat的实现保持一致:从RW阶段的Reward Model初始化而来。

你可能想问:训练Actor模型我能理解,但我还是不明白,为什么要单独训练一个Critic模型用于预测收益呢?

这是因为,当我们在前文讨论总收益(即时 + 未来)时,我们是站在上帝视角的,也就是这个就是客观存在的、真正的总收益。但是我们在训练模型时,就没有这个上帝视角加成了,也就是在时刻,我们给不出客观存在的总收益,我们只能训练一个模型去预测它。

所以总结来说,在RLHF中,我们不仅要训练模型生成符合人类喜好的内容的能力(Actor),也要提升模型对人类喜好量化判断的能力(Critic)。这就是Critic模型存在的意义。我们来看看它的大致架构:

f6ca783c-b223-11ee-8b88-92fbcf53809c.png

deepspeed-chat采用了Reward模型作为它的初始化,所以这里我们也按Reward模型的架构来简单画画它。你可以简单理解成,Reward/Critic模型和Actor模型的架构是很相似的(毕竟输入都一样),同时,它在最后一层增加了一个Value Head层,该层是个简单的线形层,用于将原始输出结果映射成单一的值。

在图中,表示Critic模型对时刻及未来(response完成)的收益预估。

3.4 Reward Model(奖励模型)

Reward Model用于计算生成token 的即时收益,它就是RW阶段所训练的奖励模型,在RLHF过程中,它的参数是冻结的。

你可能想问:为什么Critic模型要参与训练,而同样是和收益相关的Reward模型的参数就可以冻结呢?

这是因为,Reward模型是站在上帝视角的。这个上帝视角有两层含义:

  • 第一点,Reward模型是经过和“估算收益”相关的训练的,因此在RLHF阶段它可以直接被当作一个能产生客观值的模型。
  • 第二点,Reward模型代表的含义就是“即时收益”,你的token 已经产生,因此即时收益自然可以立刻算出。

你还可能想问:我已经用Critic预测出了,而这个包含了“即时”和“未来”的概念,那我还需要代表“即时”的做什么呢?直接用不就好了吗?

为了解答这个问题,我们先回顾下1.2部分中给出的价值函数:

这个函数告诉我们,我们当前可以用两个结果来表示时刻的总收益:

  • 结果1:Critic模型预测的

  • 结果2:Reward模型预测的和critic模型预测的

那么哪一个结果更靠近上帝视角给出的客观值呢?当然是结果2,因为结果1全靠预测,而结果2中的是事实数据。

我们知道Critic模型也是参与参数更新的,我们可以用MSE(上帝视角的客观收益-Critic模型预测的收益)来衡量它的loss。但是上帝视角的客观收益我们是不知道的,只能用已知事实数据去逼近它,所以我们就用来做近似。这就是同时存在的意义

Reward模型和critic模型非常相似,这里我们就只给出架构图,不再做过多的说明。关于Reward模型的训练过程,后续有时间也会出个原理和代码解析。

f6d4d836-b223-11ee-8b88-92fbcf53809c.png

四、RLHF中的loss计算

到目前为止,我们已经基本了解了RLHF的训练框架,以及其中的四个重要角色(训练一个RLHF,有4个模型在硬件上跑,可想而知对存储的压力)。在本节中,我们一起来解读RLHF的loss计算方式。在解读中,我们会再一次理一遍RLHF的整体训练过程,填补相关细节。在这之后,我们就可以来看代码解析了。

在第三部分的讲解中,我们知道Actor和Critic模型都会做参数更新,所以我们的loss也分成2个:

  • Actor loss:用于评估Actor是否产生了符合人类喜好的结果,将作用于Actor的BWD上。

  • Critic loss:用于评估Critic是否正确预测了人类的喜好,将作用于Critic的BWD上。

我们详细来看这两者。

4.1 Actor loss

(1)直观设计

我们先来看一个直观的loss设计方式:

  • Actor接收到当前上文,产出token()

  • Critic根据,产出对总收益的预测

  • 那么Actor loss可以设计为:

求和符号表示我们只考虑response部分所有token的loss,为了表达简便,我们先把这个求和符号略去(下文也是同理),也就是说:

我们希望minimize这个actor_loss。

这个设计的直观解释是:

  • 当时,意味着Critic对Actor当前采取的动作给了正向反馈,因此我们就需要在训练迭代中提高,这样就能达到减小loss的作用。
  • 当时,意味着Critic对Actor当前采取的动作给了负向反馈,因此我们就需要在训练迭代中降低,这样就能到达到减小loss的作用。

一句话总结:这个loss设计的含义是,对上文而言,如果token产生的收益较高,那就增大它出现的概率,否则降低它出现的概率。

(2)引入优势(Advantage)

在开始讲解之前,我们举个小例子:

假设在王者中,中路想支援发育路,这时中路有两种选择:1. 走自家野区。2. 走大龙路。

中路选择走大龙路,当她做出这个决定后,Critic告诉她可以收1个人头。结果,此刻对面打野正在自家采灵芝,对面也没有什么苟草英雄,中路一路直上,最终收割2个人头。

因为实际收割的人头比预期要多1个,中路尝到了甜头,所以她增大了“支援上路走大龙路”的概率。

这个多出来的“甜头”,就叫做“优势”(Advantage)。

对NLP任务来说,如果Critic对的总收益预测为,但实际执行后的总收益是,我们就定义优势为:

我们用替换掉,则此刻actor_loss变为:

(3)重新设计

总结一下,到目前为止,我们的actor_loss形式为:

其中,

同时注意,这个actor_loss应该是response的所有token loss的sum或者avg。这里为了表达方便,我们的公式略去了求和或求平均的符号。

按照这个理解,应该表示每个Actor产出token 带来的即时收益,正如下图所示(其中表示最后一个时刻):

f6e124e2-b223-11ee-8b88-92fbcf53809c.png

但在deepspeed-chat的RLHF实践中,对做了另一种设计:

  • :常量,可以理解成是一个控制比例的缩放因子,在deepspeed-chat中默认设为0.1

  • :这一项你是不是非常眼熟,这就是我们在3.2部分介绍的Actor和Ref模型间的KL散度呀,写成更容易理解的形式,就是ref_log_probs - log_probs。在3.2中我们说过,为了防止模型训歪,我们需要把这个KL散度加入loss计算中,所以这里我们就在做这件事

基于这些,上面这个对的设计可理解成:

  • 时,我们更加关心Actor是否有在Ref的约束下生产token
  • 时,我们不仅关心Actor是否遵从了Ref的约束,也关心真正的即时收益

需要注意的是,的设计并不只有这一种(其实我觉得只取时刻的不太合理,deepspeed在自己的代码注释中也有提过,可以尝试把最后一个时刻的替换成所有token的平均值;或者在我来看时也应该纳入对的考虑)。

代码实践如下:

defcompute_rewards(self,prompts,log_probs,ref_log_probs,reward_score,
action_mask):
"""
 reward_function:计算最终的reward分数
复习一下几个相关参数的默认值:
self.kl_ctl=0.1
self.clip_reward_value=5

对于batch中的某个prompt来说,它最终的reward分数为:
(1)先计算actor和ref_model的logit相似度:-self.kl_ctl *(log_probs - ref_log_probs)
其实写成self.kl_ctl*(ref_log_probs-log_probs)更好理解些
这个值越大,说明ref_model对actor生成的结果的认可度越高(即表明rlhf没有训歪),
没有训歪的情况下我们也应该给模型一些奖励,这个奖励就是self.kl_ctl*(ref_log_probs-log_probs)

(2)由于我们只取最后一个token对应位置的分数作为reward_score,因此我们只需要:
self.kl_ctl*(ref_log_probs-log_probs)的最后一位+reward_score

(3)同时我们对reward_score也做了大小限制,最大不超过self.clip_reward_value(超过统一给成self.clip_reward_value),
最小不低于-self.clip_reward_value(低于统一给成-self.clip_reward_value)

(4)最后返回的rewards大小为:(batch_size, 各条数据的长度),对batch中的每条数据来说:
- response的最后一位:self.kl_ctl *(ref_log_probs - log_probs)的最后一位+ reward_score
- response的其余位置:self.kl_ctl *(ref_log_probs - log_probs)

"""

kl_divergence_estimate=-self.kl_ctl*(log_probs-ref_log_probs)
rewards=kl_divergence_estimate
#---------------------------------------------------------------------------------------------------
#response开始的位置
#(因为我们对prompt做过padding处理,因此batch中每个prompt长度一致,也就意味着每个response开始的位置一致)
#(所以这里start是不加s的,只是一个int)
#---------------------------------------------------------------------------------------------------
start=prompts.shape[1]-1
#---------------------------------------------------------------------------------------------------
#response结束的位置
#(因为一个batch中,每个response的长度不一样,所以response的结束位置也不一样)
#(所以这里end是加s的,ends的尺寸是(batch_size,)
#---------------------------------------------------------------------------------------------------
ends=start+action_mask[:,start:].sum(1)+1
#---------------------------------------------------------------------------------------------------
#对rewards_score做限制
#---------------------------------------------------------------------------------------------------
reward_clip=torch.clamp(reward_score,-self.clip_reward_value,
self.clip_reward_value)
batch_size=log_probs.shape[0]
forjinrange(batch_size):
rewards[j,start:ends[j]][-1]+=reward_clip[j]#

returnrewards

(4)重新设计优势

好,再总结一下,目前为止我们的actor_loss为:

其中,

同时,我们对进行来改造,使其能够衡量Actor模型是否遵从了Ref模型的约束。

现在我们把改造焦点放在上,回想一下,既然对于收益而言,分为即时和未来,那么对于优势而言,是不是也能引入对未来优势的考量呢?这样,我们就可以把改写成如下形式:

(熟悉强化学习的朋友应该能一眼看出这是GAE,这里我们不打算做复杂的介绍,一切都站在直觉的角度理解)

其中,新引入的也是一个常量,可将其理解为权衡因子,直觉上看它控制了在计算当前优势时对未来优势的考量。(从强化学习的角度上,它控制了优势估计的方差和偏差)

看到这里,你可能想问:这个代表未来优势的,我要怎么算呢?

注意到,对于最后一个时刻,它的未来收益()和未来优势()都是0,也就是,这是可以直接算出来的。而有了,我们不就能从后往前,通过动态规划的方法,把所有时刻的优势都依次算出来了吗?

代码实践如下(其中返回值中的returns表示实际收益,将被用于计算Critic模型的loss,可以参见4.2,其余细节都在代码注释中):

defget_advantages_and_returns(self,values,rewards,start):
"""
Adoptedfromhttps://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134

没有引入GAE前的t时刻的优势值:
detal_t=r_t+gamma*V_t+1-V_t
其中:
-r_t表示t时刻的即时收益
-V_t+1表示未来时刻的预期收益
-r_t+gamma*V_t+1可理解成t时刻的实际预期收益
-V_t可理解成t时刻的预估预期收益(是模型,例如criticmodel自己估算出来的)

引入GAE后的t时刻的优势值:
A_t=delta_t+gamma*lambda*A_t+1
粗暴理解为在t时刻时,不仅考虑当下优势,还考虑了未来的优势
为了知道A_t, 我们得知道A_t+1,所以在本算法中采取了从后往前做动态规划求解的方法,也即:
假设T是最后一个时刻,则有A_T+1=0,所以有:A_T=delta_T
知道了A_T,就可以依次往前倒推,把A_t-1,A_t-2之类都算出来了

引入GAE后t时刻的实际预期收益
returns_t=A_t+V_t
=delta_t+gamma*lambda*A_t+1+V_t
=r_t+gamma*V_t+1-V_t+gamma*lambda*A_t+1+V_t
=r_t+gamma*(V_t+1+lambda*A_t+1)

注意,这里不管是advantages还是returns,都只算response的部分
"""

#Adoptedfromhttps://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
lastgaelam=0
advantages_reversed=[]
length=rewards.size()[-1]
#注意这里用了reversed,是采取从后往前倒推计算的方式
fortinreversed(range(start,length)):
nextvalues=values[:,t+1]ift< length - 1else0.0
delta=rewards[:,t]+self.gamma*nextvalues-values[:,t]
lastgaelam=delta+self.gamma*self.lam*lastgaelam
advantages_reversed.append(lastgaelam)
advantages=torch.stack(advantages_reversed[::-1],dim=1)#优势
returns=advantages+values[:,start:]#实际收益
#values:预期收益
returnadvantages.detach(),returns

(5) PPO-epoch: 引入新约束

总结一下,目前为止我们的actor_loss为:

其中,

同时

  • 我们已经对进行来改造,使其能够衡量Actor模型是否遵从了Ref模型的约束。
  • 我们已经对进行改造,使其不仅考虑了当前时刻的优势,还考虑了未来的优势

基于这些改造,我们重新理一遍RLHF-PPO的训练过程。

f6f2920e-b223-11ee-8b88-92fbcf53809c.png
  • 第一步,我们准备一个batch的prompts

  • 第二步,我们将这个batch的prompts喂给Actor模型,让它生成对应的responses

  • 第三步,我们把prompt+responses喂给我们的Critic/Reward/Reference模型,让它生成用于计算actor/critic loss的数据,按照强化学习的术语,我们称这些数据为经验(experiences)。critic loss我们将在后文做详细讲解,目前我们只把目光聚焦到actor loss上

  • 第四步,我们根据这些经验,实际计算出actor/critic loss,然后更新Actor和Critic模型

这些步骤都很符合直觉,但是细心的你肯定发现了,文字描述中的第四步和图例中的第四步有差异:图中说,这一个batch的经验值将被用于n次模型更新,这是什么意思呢?

我们知道,在强化学习中,收集一个batch的经验是非常耗时的。对应到我们RLHF的例子中,收集一次经验,它要等四个模型做完推理才可以,正是因此,一个batch的经验,只用于计算1次loss,更新1次Actor和Critic模型,好像有点太浪费了。

所以,我们自然而然想到,1个batch的经验,能不能用来计算ppo-epochs次loss,更新ppo-epochs次Actor和Critic模型?简单写一下伪代码,我们想要:

#--------------------------------------------------------------
#初始化RLHF中的四个模型
#--------------------------------------------------------------
actor,critic,reward,ref=initialize_models()

#--------------------------------------------------------------
#训练
#--------------------------------------------------------------
#对于每一个batch的数据
foriinsteps:
#先收集经验值
exps=generate_experience(prompts,actor,critic,reward,ref)
#一个batch的经验值将被用于计算ppo_epochs次loss,更新ppo_epochs次模型
#这也意味着,当你计算一次新loss时,你用的是更新后的模型
forjinppo_epochs:
actor_loss=cal_actor_loss(exps,actor)
critic_loss=cal_critic_loss(exps,critic)

actor.backward(actor_loss)
actor.step()

critc.backward(critic_loss)
critic.step()

而如果我们想让一个batch的经验值被重复使用ppo_epochs次,等价于我们想要Actor在这个过程中,模拟和环境交互ppo_epochs次。举个例子:

  • 如果1个batch的经验值只使用1次,那么在本次更新完后,Actor就吃新的batch,正常和环境交互,产出新的经验值

  • 但如果1个batch的经验值被使用ppo_epochs次,在这ppo_epochs中,Actor是不吃任何新数据,不做任何交互的,所以我们只能让Actor“模拟”一下和环境交互的过程,吐出一些新数据出来。

那怎么让Actor模拟呢?很简单,让它观察一下之前的数据长什么样,让它依葫芦画瓢,不就行了吗?我们假设最开始吃batch,吐出经验的actor叫,而在伪代码中,每次做完ppo_epochs而更新的actor叫,那么我们只要尽量保证每次更新后的能模仿最开始的那个,不就行了吗?

诶!是不是很眼熟!两个分布,通过什么方法让它们相近!那当然是KL散度!所以,再回到我们的actor_loss上来,它现在就可被改进成:

我们再稍作一些改动将log去掉(这个其实不是“稍作改动去掉log”的事,是涉及到PPO中重要性采样的相关内容,大家有兴趣可以参考https://www.cnblogs.com/xingzheai/p/15931681.html):

其中,表示真正吃了batch,产出经验值的Actor;P表示ppo_epochs中实时迭代更新的Actor,它在模仿的行为。所以这个公式从直觉上也可以理解成:在Actor想通过模拟交互的方式,使用一个batch的经验值更新自己时,它需要收到真正吃到batch的那个时刻的Actor的约束,这样才能在有效利用batch,提升训练速度的基础上,保持训练的稳定。


但是,谨慎的你可能此时又有新的担心了:虽然我们在更新Actor的过程中用做了约束,但如的约束能力不够,比如说还是超出了可接受的范围,那怎么办?

很简单,那就剪裁(clip)它吧!

我们给设置一个范围,例如(0.8 ,1.2),也就是如果这个值一旦超过1.2,那就统一变成1.2;一旦小于0.8,那就统一变成0.8。这样就能保证和的分布相似性在我们的掌控之内了。此时actor_loss变为:

这时要注意,如果超过变化范围,将强制设定为一个常数后,就说明这一部分的loss和Actor模型无关了,而这项本身也与Actor无关。所以相当于,在超过约束范围时,我们停止对Actor模型进行更新。

整体代码如下:

defactor_loss_fn(self,logprobs,old_logprobs,advantages,mask):
"""
logprobs:实时计算的,response部分的prob(只有这个是随着actor实时更新而改变的)
 old_logprobs:老策略中,response部分的prob (这个是固定的,不随actor实时更新而改变)
 advantages:老策略中,response部分每个token对应的优势(这个是固定的,不随actor实时更新而改变)
 mask:老策略中,response部分对应的mask情况这个是固定的,不随actor实时更新而改变)

之所以要引入logprobs计算actor_loss,是因为我们不希望策略每次更新的幅度太大,防止模型训歪

self.cliprange:默认值是0.2
"""
##policygradientloss
#-------------------------------------------------------------------------------------
#计算新旧策略间的KL散度
#-------------------------------------------------------------------------------------
log_ratio=(logprobs-old_logprobs)*mask
ratio=torch.exp(log_ratio)
#-------------------------------------------------------------------------------------
#计算原始loss和截断loss
#-------------------------------------------------------------------------------------
pg_loss1=-advantages*ratio
pg_loss2=-advantages*torch.clamp(ratio,1.0-self.cliprange,1.0+self.cliprange)
pg_loss=torch.sum(torch.max(pg_loss1,pg_loss2)*mask)/mask.sum()#最后是取每个非mask的responsetoken的平均loss作为最终loss
returnpg_loss

(6)Actor loss小结

(1)~(5)中我们一步步树立了actor_loss的改进过程,这里我们就做一个总结吧:

其中:

  • 我们已经对进行来改造,使其能够衡量Actor模型是否遵从了Ref模型的约束

  • 我们已经对进行改造,使其不仅考虑了当前时刻的优势,还考虑了未来的优势

  • 我们重复利用了1个batch的数据,使本来只能被用来做1次模型更新的它现在能被用来做ppo_epochs次模型更新。我们使用真正吃了batch,产出经验值的那个时刻的Actor分布来约束ppo_epochs中更新的Actor分布

  • 我们考虑了剪裁机制(clip),在ppo_epochs次更新中,一旦Actor的更新幅度超过我们的控制范围,则不对它进行参数更新。

4.2 Critic loss

我们知道,1个batch产出的经验值,不仅被用来更新Actor,还被用来更新Critic。对于Critic loss,我们不再像Actor loss一样给出一个“演变过程”的解读,我们直接来看它最后的设计。

首先,在之前的解说中,你可能有这样一个印象:

  • :Critic对t时刻的总收益的预估,这个总收益包含即时和未来的概念(预估收益)

  • :Reward计算出的即时收益,Critic预测出的及之后时候的收益的折现,这是比更接近t时刻真值总收益的一个值(实际收益)

所以,我们的第一想法是:

现在,我们对“实际收益”和“预估收益”都做一些优化。

(1)实际收益优化

我们原始的实际收益为,但是当我们在actor_loss中引入“优势”的概念时,“优势”中刻画了更为丰富的实时收益信息,所以,我们将实际收益优化为:

(2)预估收益优化

我们原始的预估收益为。

类比于Actor,Critic模型在ppo_epochs的过程中也是不断更新的。所以这个可以理解成是,也就是真正吃了batch,参与产出经验的那个时候的Critic产出的收益预测结果。

我们同样想用旧模型去约束新模型,但对于Critic我们采用的约束策略就比较简单了,我们直接看代码,从中可以看出,我们用老设计了了一个变动范围,然后用这个变动范围去约束新

#self.cliprange_value是一个常量
#old_values:老critic的预测结果
# values:新critic的预测结果
values_clipped=torch.clamp(
values,
old_values-self.cliprange_value,
old_values+self.cliprange_value,
)

那么最终我们就取实际收益和预估收益的MSE做为loss就好,这里注意,计算实际收益时都是老Critic(真正吃了batch的那个)产出的结果,而预估收益是随着ppo_epochs而变动的。

代码如下:

defcritic_loss_fn(self,values,old_values,returns,mask):
"""
values:实时critic跑出来的预估预期收益(是变动的,随着ppoepoch迭代而改变)
 old_values:老critic跑出来的预估预期收益(是固定值)
 returns:实际预期收益
 mask:response部分的mask

self.cliprange_value=0.2
"""
##valueloss
#用旧的value去约束新的value
values_clipped=torch.clamp(
values,
old_values-self.cliprange_value,
old_values+self.cliprange_value,
)
ifself.compute_fp32_loss:
values=values.float()
values_clipped=values_clipped.float()

#critic模型的loss定义为(预估预期收益-实际预期收益)**2
vf_loss1=(values-returns)**2
vf_loss2=(values_clipped-returns)**2
vf_loss=0.5*torch.sum(
torch.max(vf_loss1,vf_loss2)*mask)/mask.sum()#同样,最后也是把criticloss平均到每个token上
returnvf_loss

至此,关于RLHF-PPO训练的核心部分和代码解读就讲完了,建议大家亲自阅读源码,如果有硬件条件,可以动手实践。源码地址:

https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 源码
    +关注

    关注

    8

    文章

    574

    浏览量

    28595
  • 强化学习
    +关注

    关注

    4

    文章

    259

    浏览量

    11115
  • 大模型
    +关注

    关注

    2

    文章

    1555

    浏览量

    1148

原文标题:图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    微软4月1日推出生成式AI安全产品“Securit Copilot”

    脚本反编程:自动解码恶意软件,实现无须手工逆向工程,让每位分析师都能看懂入侵者的操作;分析复杂命令行脚本,实现自然语言解释,找到相关实体并关联;
    的头像 发表于 03-14 10:28 169次阅读

    程序中的R地址都是什么意思?怎么样才能看懂

    程序中的R地址都是什么意思?怎么样才能看懂? 在程序中,R地址通常指的是寄存器地址,它是用来存储和访问计算机中的数据的硬件部件。寄存器是计算机中最快的内存形式,它位于中央处理器(CPU)内部或与
    的头像 发表于 02-18 10:49 314次阅读

    OneFlow Softmax算子源码解读之BlockSoftmax

    写在前面:笔者这段时间工作太忙,身心俱疲,博客停更了一段时间,现在重新捡起来。本文主要解读 OneFlow 框架的第二种 Softmax 源码实现细节,即 block 级别的 Softmax。
    的头像 发表于 01-08 09:26 324次阅读
    OneFlow Softmax算子<b class='flag-5'>源码</b><b class='flag-5'>解读</b>之BlockSoftmax

    2022年汽车企业芯片版图解读

    电子发烧友网站提供《2022年汽车企业芯片版图解读.rar》资料免费下载
    发表于 12-13 10:47 0次下载
    2022年汽车企业芯片版<b class='flag-5'>图解读</b>

    拆解大语言模型RLHF中的PPO算法

    由于本文以大语言模型 RLHFPPO 算法为主,所以希望你在阅读前先弄明白大语言模型 RLHF 的前两步,即 SFT Model 和
    的头像 发表于 12-11 18:30 1302次阅读
    拆解大语言<b class='flag-5'>模型</b><b class='flag-5'>RLHF</b>中的<b class='flag-5'>PPO</b>算法

    小白都能看懂的云计算入门热文

    2006 年 8 月 9 日,当时的谷歌首席执行官埃里克·施密特(Eric Schmidt)在搜索引擎大会(SES San Jose 2006)上,首次提出了“云计算”(Cloud Computing)的概念。 而就在大会的 5 个月之前,2006 年 3 月,电商起家的美国亚马逊公司正式推出了自家的弹性计算云(Elastic Compute Cloud,EC2)服务。 这两个标志性事件的发生,正式宣告了云计算时代的到来,也意味着互联网的发展进入了一个新的阶段。 时至今日,十七年过去了,云计算经历了质疑,也经历了狂热,逐
    的头像 发表于 11-09 11:37 241次阅读

    求助,求个示波器上位机的源码

    最好是vcc#实在是有点儿不会。vcVc多少还能看懂一些 下位机程序。已经没有什么问题。现在卡在上位机上了 自己试着做了一个源码但是在核心的问题上,不知道该怎么处理 。 就是这个屏幕打点方法测试
    发表于 10-25 08:31

    电路图你也能看懂

    电子发烧友网站提供《电路图你也能看懂.pdf》资料免费下载
    发表于 10-07 09:58 20次下载

    大语言模型(LLM)预训练数据集调研分析

    model 训练完成后,使用 instruction 以及其他高质量的私域数据集来提升 LLM 在特定领域的性能;而 rlhf 是 openAI 用来让model 对齐人类价值观的一种强大技术;pre-training dataset 是大模型在训练时真正喂给 mode
    的头像 发表于 09-19 10:00 576次阅读
    大语言<b class='flag-5'>模型</b>(LLM)预训练数据集调研分析

    什么是运放 反相比例运放电路图

     只要记住Uo = A * (Up-Un)和“虚短”、“虚断”,理想运放的电路都能看懂。这里先不要纠结为什么会是这样,有机会后面会介绍。这里先介绍一个最简单的运放电路:反相比例放大电路。
    发表于 09-03 10:58 961次阅读
    什么是运放 反相比例运放电路图

    RLHF实践中的框架使用与一些坑 (TRL, LMFlow)

    我们主要用一个具体的例子展示如何在两个框架下做RLHF,并且记录下训练过程中我们踩到的主要的坑。这个例子包括完整的SFT,奖励建模和 RLHF, 其中RLHF包括通过 RAFT 算法(Reward rAnked FineTuni
    的头像 发表于 06-20 14:36 1489次阅读
    <b class='flag-5'>RLHF</b>实践中的框架使用与一些坑 (TRL, LMFlow)

    图解模型训练之:Megatron源码解读2,模型并行

    前文说过,用Megatron做分布式训练的开源大模型有很多,我们选用的是THUDM开源的CodeGeeX(代码生成式大模型,类比于openAI Codex)。选用它的原因是“完全开源”与“清晰的模型架构和预训练配置图”,能帮助我
    的头像 发表于 06-07 15:08 2625次阅读
    <b class='flag-5'>图解</b>大<b class='flag-5'>模型</b>训练之:Megatron<b class='flag-5'>源码</b><b class='flag-5'>解读</b>2,<b class='flag-5'>模型</b>并行

    图解模型系列之:Megatron源码解读1,分布式环境初始化

    使用Megatron来训练gpt类大模型的项目有很多。在这个系列里,我选择了由THUDM开发的CodeGeeX项目,它是gpt在代码生成方向上的应用,对标于openAI的CodeX。github地址在此。
    的头像 发表于 06-06 15:22 4114次阅读
    <b class='flag-5'>图解</b>大<b class='flag-5'>模型</b><b class='flag-5'>系列</b>之:Megatron<b class='flag-5'>源码</b><b class='flag-5'>解读</b>1,分布式环境初始化

    草履虫都能看得明白的FOC 入门教程

    草履虫都能看得明白的FOC 入门教程 其利天下技开发 其利天下技开发 *附件:FOC技术笔记-新修版.pdf
    发表于 05-29 12:05

    这几个基础模块电路,你能看懂吗?

    文章开始前,先来考考大家~ 下面的五副电路图,你能看懂几个? TDA2030电路图 34063电路图 555电路 TDA2030电路图 三极管分立元件电路 以上这些电路图,如果你能够看懂,那说明已经
    的头像 发表于 05-20 09:00 563次阅读
    这几个基础模块电路,你<b class='flag-5'>能看懂</b>吗?