手撕经典算法 #5 RLHF 篇

本文最后更新于:2025年4月11日 晚上

本文对 RLHF 中经典的算法进行了简单的实现和注释。包括:

  • 广义优势估计(GAE)
  • 各种损失函数(PPO、DPO、GRPO)

GAE

PPO 中的广义优势估计(Generalized Advantage Estimation,GAE) 通过在时间步上对 TD Error 进行加权累加,提供了一个在偏差和方差之间可调的优势估计器。其定义为: \[ \begin{aligned} A_t^{\text{GAE}(\gamma, \lambda)} &= \delta_t + (\gamma \lambda) \delta_{t+1} + (\gamma \lambda)^2 \delta_{t+2} + \cdots \\ &= \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}\\ \end{aligned} \] 其中:

  • \(\delta_t\) 是第 \(t\) 步的 TD Error:\(\delta_t = r_t + \gamma v_t(s_{t+1}) - v_t(s_t)\),这里为了简洁将第 \(t\) 步的即时奖励记为 \(r_t\)
  • \(\gamma\) 是折扣因子,\(\lambda \in [0, 1]\)GAE 的衰减系数,控制偏差和方差之间的平衡

通过调整 \(\lambda\) 的值,可以在偏差和方差之间进行调节:

  • \(\lambda = 0\) 时,只考虑一步的 TD Error,偏差较大,但方差较小。
  • \(\lambda = 1\) 时,优势估计等价于蒙特卡洛方法,偏差较小,但方差较大。

在 PPO 实现的过程中,我们需要对每个 Token 都计算出 GAE,因此需要反向遍历,用递推的形式逐步计算。实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def GAE(values, rewards, gamma, lambd):
# values: (batch_size, seq_len)
# rewards: (batch_size, seq_len)

suffix_gae = 0
advantages_reversed = []
response_length = rewards.size(1)

# 从最后一个时间步开始向前遍历
for t in reversed(range(response_length)):
nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
delta = rewards[:, t] + gamma * nextvalues - values[:, t] # 计算 TD error
suffix_gae = delta + gamma * lambd * suffix_gae # 递推计算 GAE
advantages_reversed.append(suffix_gae) # 当前时间步的优势

# advantages: [batch_size, seq_len]
advantages = torch.stack(advantages_reversed[::-1], dim=1) # 将逆序列表反转为正序
return advantages.detach()

PPO

在真实的 PPO 代码中,可能包含策略模型损失函数、价值模型损失函数、语言模型损失函数,用于优化 Actor、Critic 和防止模型遗忘预训练知识。这里我们仅介绍常见的 PolicyLoss 和 ValueLoss。

Policy Loss

策略模型损失函数可以表示为: \[ L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min \left( r_t(\theta) A_t,\; \operatorname{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right] \] 用于优化 Actor 模型,为确保策略更新不会偏离旧策略太远,通过计算新旧策略的概率比率,并使用裁剪参数 \(\epsilon\) 与裁剪函数 clamp 限制更新幅度,从而稳定训练过程。

1
2
3
4
5
6
7
8
9
10
def PolicyLoss(log_probs, old_log_probs, advantages)
"""
log_probs: (batch_size, seq_len),表示某个被选中的 action 的对数概率
advantages: (batch_size, seq_len)
"""
ratio = (log_probs - old_log_probs).exp() # 计算概率比率
surr1 = ratio * advantages # 计算未裁剪的目标函数
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages # 计算裁剪后的目标函数
loss = -torch.min(surr1, surr2) # 取两者中的较小值
return loss.mean()

注意:这里的 log_probs 表示采样的 Experience 中选中 Action 的对数概率,经过一层 Log 变换:

1
2
3
4
5
6
7
8
# 模型输出原始 logits (batch_size, seq_len, vocab_size)
logits = actor_model(input_ids)

# 计算概率分布(softmax 后的概率)
probs = F.softmax(logits, dim=-1) # (batch_size, seq_len, vocab_size)

# 提取被采样 token 的 log 概率
log_probs = torch.gather(probs.log(), dim=-1, index=actions.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len)

Value Loss

价值模型损失函数可以表示为: \[ L^{\text{Critic}}(\phi) = \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \| V_\phi(x) - V_t^{\text{target}} \|^2 \right] \]

其中,\(V_\phi(x)\) 是价值模型对状态 \(x\) 的价值估计,\(V_t^{\text{target}}\) 在 RLHF-PPO 中一般定义为回报值 \(R_t = \hat{A}_{t}^{\mathrm{GAE}(\gamma,\lambda)} + V(s_t)\)

虽然在 PPO 原论文中并未对价值函数引入裁剪,但在一些 PPO 的变体中,为了提高训练的稳定性,也在价值函数的训练中引入了裁剪操作。同样,通过参数 \(\epsilon\) ,约束新策略 \(V_\phi^{new}\) 与旧策略 \(V_\phi^{old}\) 之间的变化幅度。

实现代码如下:

1
2
3
4
5
6
7
8
9
10
def ValueLoss(values, old_values, returns):
"""
values: (batch_size, seq_len)
returns: (batch_size, seq_len)
"""
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2 # 未经裁剪的价值损失
loss = torch.max(surr1, surr2) # 取较大的损失(最大值)
return 0.5 * loss.mean()

DPO Loss

通过极大似然估计,BT 模型的优化目标可以转化为二元交叉熵损失: \[ \mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E}_{(x,y_w,y_l) \sim D} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta \log \frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)} \right) \right] \]

注意,DPO的目标是比较两个完整序列的整体偏好差异,而不是逐个token的生成质量。因此:被选中/拒绝响应的对数概率是对序列所有token的对数概率求和或者取平均,维度为 (batch_size,)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
policy_chosen_logps, policy_rejected_logps = concatenated_forward(
self.policy_model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens
) # 策略模型的概率,(batch_size,)

reference_chosen_logps, reference_rejected_logps = concatenated_forward(
self.ref_model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens
) # 参考模型的概率,(batch_size,)

def DPOLoss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
"""
policy_chosen_logps: (batch_size,)
policy_rejected_logps: (batch_size,)
"""
# 计算策略模型和参考模型的 log 概率比
pi_logratios= policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios # 构造 logits 差值

losses = -F.logsigmoid(self.beta * logits) # -log σ(beta * (策略比值 - 参考比值))
return losses.mean()

GRPO Loss

Reward 归一化处理

对于每个问题 \(q\) ,首先从旧策略模型 \(\pi_{o_\mathrm{old}}\) 中采样生成一组包含 \(G\) 个输出的结果 \(\{o_1,o_2,\cdots,o_G\}\)。利用奖励模型对这些输出进行评分,得到对应的奖励值 \(r=\{r_1,r_2,\cdots,r_G\}\)。对奖励进行归一化处理,该输出中所有 token 的优势值 \(\hat{A}_{i,t}\) 设为归一化的奖励值,即\(\hat{A}_{i,t}=\tilde{r}_i\quad(\forall t \in o_i).\)

1
2
3
4
5
6
7
rewards = [experience.info["reward"] for experience in experiences]

# 每行对应一个 prompt 的 n 个样本
rewards = torch.cat(rewards).reshape(-1, args.n_samples_per_prompt) # [batch_size, n_samples_per_prompt]
# 计算均值与组标准差
rewards = (rewards - rewards.mean(-1, keepdim=True)) / (rewards.std(-1, keepdim=True) + 1e-9)
rewards = rewards.reshape(-1).chunk(len(experiences))

Loss 计算

GRPO 的核心改进在于消除价值函数,通过组内相对优势估计代替绝对优势评估。其目标函数结合 PPO 的剪切机制与 KL 惩罚项,具体实现可分为以下步骤:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def GRPOLoss(log_probs, old_log_probs, ref_log_probs, advantages, clip_eps=0.2, beta=0.04):
"""
log_probs: (batch_size * G, seq_len) 新策略的对数概率
old_log_probs: (batch_size * G, seq_len) 旧策略的对数概率
ref_log_probs: (batch_size * G, seq_len) 参考模型(如 SFT 模型)的对数概率
advantages: (batch_size * G, seq_len) 归一化后的优势值(每个 token 相同)
"""
# 计算策略比率与剪切损失
ratio = (log_probs - old_log_probs).exp() # (batch*G, seq_len)
surr1 = ratio * advantages # 未裁剪目标
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages # 裁剪目标
policy_loss = -torch.min(surr1, surr2) # 取较小值防止过激更新

# 计算 KL 散度惩罚项(当前策略 vs 参考模型)
kl_loss = log_probs - ref_log_probs # 近似 KL 散度

total_loss = (policy_loss - beta * kl_loss).mean()
return total_loss

无偏 KL 散度

为了有效地计算策略模型与参考模型之间的 KL 散度,我们可以从对数概率(log_prob)的角度来实现这一过程: \[ \mathrm{KL}\left(\pi_\theta^{\mathrm{RL}}(y \mid x) \parallel \pi^{\mathrm{SFT}}(y \mid x)\right) = \sum_{t} \left( \log \pi_\theta^{\mathrm{RL}}(y_t \mid x, y_{<t}) - \log \pi^{\mathrm{SFT}}(y_t \mid x, y_{<t}) \right) \] 具体步骤如下:

  1. 使用策略模型进行采样:对输入 \(x\) 生成的输出 \(y\)
  2. 计算策略模型的对数概率(log_probs):计算策略模型在每个时间步(Token)的对数概率。
  3. 计算参考模型的对数概率(ref_log_probs):对相同的输入输出,使用参考模型计算逐 Token 的对数概率。
  4. 计算 KL 散度:通过对数概率之差累加得到策略模型与参考模型之间的 KL 散度。
  5. 整合奖励:Token 级的 KL 散度与奖励模型在输出 \(y\) 上序列级 Reward 组合,得到 \(r_{\text {total}}\)
1
2
3
4
5
def compute_kl(log_probs, log_probs_base):
"""
log_probs: (batch_size, seq_len)
"""
kl = log_probs.float() - log_probs_base.float() # log(pi) - log(pi_ref)

在实际使用的时候,KL 散度还有一种非负、无偏、低方差的估计方式:

\[ \mathrm{KL}\left(\pi_\theta^{\mathrm{RL}}(y \mid x) \parallel \pi^{\mathrm{SFT}}(y \mid x)\right) = \frac{\pi^{\mathrm{SFT}}(y_t \mid x, y_{<t})}{\pi_\theta^{\mathrm{RL}}(y_t \mid x, y_{<t})} - \log \frac{\pi^{\mathrm{SFT}}(y_t \mid x, y_{<t})}{\pi_\theta^{\mathrm{RL}}(y_t \mid x, y_{<t})} - 1 \]

1
2
3
4
5
6
7
def compute_approx_kl(log_probs, log_probs_base):
"""
log_probs: (batch_size, seq_len)
"""
log_ratio = log_probs.float() - log_probs_base.float() # 保持原来的计算
log_ratio = -log_ratio # 注意公式里是反过来的
kl = log_ratio.exp() - 1 - log_ratio

手撕经典算法 #5 RLHF 篇
https://hwcoder.top/Manual-Coding-5
作者
Wei He
发布于
2025年3月20日
许可协议