手撕经典算法 #5 RLHF 篇
本文最后更新于:2025年4月11日 晚上
本文对 RLHF 中经典的算法进行了简单的实现和注释。包括:
- 广义优势估计(GAE)
- 各种损失函数(PPO、DPO、GRPO)
GAE
PPO 中的广义优势估计(Generalized Advantage Estimation,GAE) 通过在时间步上对 TD Error 进行加权累加,提供了一个在偏差和方差之间可调的优势估计器。其定义为: \[ \begin{aligned} A_t^{\text{GAE}(\gamma, \lambda)} &= \delta_t + (\gamma \lambda) \delta_{t+1} + (\gamma \lambda)^2 \delta_{t+2} + \cdots \\ &= \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}\\ \end{aligned} \] 其中:
- \(\delta_t\) 是第 \(t\) 步的 TD Error:\(\delta_t = r_t + \gamma v_t(s_{t+1}) - v_t(s_t)\),这里为了简洁将第 \(t\) 步的即时奖励记为 \(r_t\);
- \(\gamma\) 是折扣因子,\(\lambda \in [0, 1]\) 是 GAE 的衰减系数,控制偏差和方差之间的平衡。
通过调整 \(\lambda\) 的值,可以在偏差和方差之间进行调节:
- 当 \(\lambda = 0\) 时,只考虑一步的 TD Error,偏差较大,但方差较小。
- 当 \(\lambda = 1\) 时,优势估计等价于蒙特卡洛方法,偏差较小,但方差较大。
在 PPO 实现的过程中,我们需要对每个 Token 都计算出 GAE,因此需要反向遍历,用递推的形式逐步计算。实现代码如下:
1 |
|
PPO
在真实的 PPO 代码中,可能包含策略模型损失函数、价值模型损失函数、语言模型损失函数,用于优化 Actor、Critic 和防止模型遗忘预训练知识。这里我们仅介绍常见的 PolicyLoss 和 ValueLoss。
Policy Loss
策略模型损失函数可以表示为: \[ L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min \left( r_t(\theta) A_t,\; \operatorname{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right] \] 用于优化 Actor 模型,为确保策略更新不会偏离旧策略太远,通过计算新旧策略的概率比率,并使用裁剪参数 \(\epsilon\) 与裁剪函数 clamp
限制更新幅度,从而稳定训练过程。
1 |
|
注意:这里的
log_probs
表示采样的 Experience 中选中 Action 的对数概率,经过一层 Log 变换:
1
2
3
4
5
6
7
8
# 模型输出原始 logits (batch_size, seq_len, vocab_size)
logits = actor_model(input_ids)
# 计算概率分布(softmax 后的概率)
probs = F.softmax(logits, dim=-1) # (batch_size, seq_len, vocab_size)
# 提取被采样 token 的 log 概率
log_probs = torch.gather(probs.log(), dim=-1, index=actions.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len)
Value Loss
价值模型损失函数可以表示为: \[ L^{\text{Critic}}(\phi) = \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \| V_\phi(x) - V_t^{\text{target}} \|^2 \right] \]
其中,\(V_\phi(x)\) 是价值模型对状态 \(x\) 的价值估计,\(V_t^{\text{target}}\) 在 RLHF-PPO 中一般定义为回报值 \(R_t = \hat{A}_{t}^{\mathrm{GAE}(\gamma,\lambda)} + V(s_t)\)。
虽然在 PPO 原论文中并未对价值函数引入裁剪,但在一些 PPO 的变体中,为了提高训练的稳定性,也在价值函数的训练中引入了裁剪操作。同样,通过参数 \(\epsilon\) ,约束新策略 \(V_\phi^{new}\) 与旧策略 \(V_\phi^{old}\) 之间的变化幅度。
实现代码如下:
1 |
|
DPO Loss
通过极大似然估计,BT 模型的优化目标可以转化为二元交叉熵损失: \[ \mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E}_{(x,y_w,y_l) \sim D} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta \log \frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)} \right) \right] \]
注意,DPO的目标是比较两个完整序列的整体偏好差异,而不是逐个token的生成质量。因此:被选中/拒绝响应的对数概率是对序列所有token的对数概率求和或者取平均,维度为 (batch_size,)
。
1 |
|
GRPO Loss
Reward 归一化处理
对于每个问题 \(q\) ,首先从旧策略模型 \(\pi_{o_\mathrm{old}}\) 中采样生成一组包含 \(G\) 个输出的结果 \(\{o_1,o_2,\cdots,o_G\}\)。利用奖励模型对这些输出进行评分,得到对应的奖励值 \(r=\{r_1,r_2,\cdots,r_G\}\)。对奖励进行归一化处理,该输出中所有 token 的优势值 \(\hat{A}_{i,t}\) 设为归一化的奖励值,即\(\hat{A}_{i,t}=\tilde{r}_i\quad(\forall t \in o_i).\)
1 |
|
Loss 计算
GRPO 的核心改进在于消除价值函数,通过组内相对优势估计代替绝对优势评估。其目标函数结合 PPO 的剪切机制与 KL 惩罚项,具体实现可分为以下步骤:
1 |
|
无偏 KL 散度
为了有效地计算策略模型与参考模型之间的 KL 散度,我们可以从对数概率(log_prob)的角度来实现这一过程: \[ \mathrm{KL}\left(\pi_\theta^{\mathrm{RL}}(y \mid x) \parallel \pi^{\mathrm{SFT}}(y \mid x)\right) = \sum_{t} \left( \log \pi_\theta^{\mathrm{RL}}(y_t \mid x, y_{<t}) - \log \pi^{\mathrm{SFT}}(y_t \mid x, y_{<t}) \right) \] 具体步骤如下:
- 使用策略模型进行采样:对输入 \(x\) 生成的输出 \(y\)。
- 计算策略模型的对数概率(log_probs):计算策略模型在每个时间步(Token)的对数概率。
- 计算参考模型的对数概率(ref_log_probs):对相同的输入输出,使用参考模型计算逐 Token 的对数概率。
- 计算 KL 散度:通过对数概率之差累加得到策略模型与参考模型之间的 KL 散度。
- 整合奖励:Token 级的 KL 散度与奖励模型在输出 \(y\) 上序列级 Reward 组合,得到 \(r_{\text {total}}\)。
1 |
|
在实际使用的时候,KL 散度还有一种非负、无偏、低方差的估计方式:
\[ \mathrm{KL}\left(\pi_\theta^{\mathrm{RL}}(y \mid x) \parallel \pi^{\mathrm{SFT}}(y \mid x)\right) = \frac{\pi^{\mathrm{SFT}}(y_t \mid x, y_{<t})}{\pi_\theta^{\mathrm{RL}}(y_t \mid x, y_{<t})} - \log \frac{\pi^{\mathrm{SFT}}(y_t \mid x, y_{<t})}{\pi_\theta^{\mathrm{RL}}(y_t \mid x, y_{<t})} - 1 \]
1 |
|