最近套磁到了港科的一个 RA 岗位,组里在做 Agent 强化学习相关的研究,按照学长的建议先把 PPO 算法系统学习一遍,于是有了这篇笔记。
在强化学习中,常用符号约定如下:
- at:t 时刻 Agent 采取的动作(action);
- st:t 时刻 Agent 所处的状态(state);
- π:Agent 的策略函数(policy),输入状态,输出每个动作的概率分布 π(at∣st);
- rt:t 时刻 Agent 采取动作后获得的奖励(reward);
- τ:一条轨迹(trajectory),即一段状态-动作序列;
- Episode:一次完整的交互过程,从环境初始化开始,到达到终止状态为止;
- Rollout:按当前策略实际"跑出来"的一段轨迹数据,不一定是完整的一局,也可以只是其中的一段。
(s0,a0,s1,a1…sT)
环境的状态转移满足 st+1=f(st,at)(确定性环境)或 st+1∼P(⋅∣st,at)(随机性环境)。
Return(回报)指从当前时间步到 episode 结束所获得奖励的累积和(或带折扣的累积和)。
我们要训练一个策略网络 πθ,使其在所有可能的状态下做出动作后,期望回报最大化:
E(R(τ))τ∼Pθ(τ)=τ∑R(τ)Pθ(τ)
∇E(R(τ))τ∼Pθ(τ)=∇τ∑R(τ)Pθ(τ)=τ∑R(τ)∇Pθ(τ)=τ∑R(τ)∇Pθ(τ)⋅Pθ(τ)Pθ(τ)=τ∑Pθ(τ)R(τ)Pθ(τ)∇Pθ(τ)≈N1n=1∑NR(τn)Pθ(τn)∇pθ(τn)=N1n=1∑NR(τn)∇logPθ(τn)=N1n=1∑NR(τn)∇logt=1∏TnPθ(ant∣snt)=N1n=1∑NR(τn)t=1∑Tn∇logPθ(ant∣snt)=N1n=1∑Nt=1∑TnR(τn)∇logPθ(ant∣snt)Loss=−N1n=1∑Nt=1∑TnR(τn)logPθ(ant∣snt)
R(τn)=t=t′∑Tnγt−t′rt′n=Rtn
为了衡量某个动作"相对好坏",我们让奖励 Rt 减去一个基准值 B(st),从而降低梯度估计的方差:
N1n=1∑Nt=1∑Tn(Rtn−B(stn))∇logPθ(atn∣stn)
然而 Rtn 仅来自一次随机采样,方差很大、训练不稳定。一个自然的改进是引入价值函数:用 Qθ(s,a) 表示在状态 s 下采取动作 a 的期望回报,即 动作价值函数;用 Vθ(s) 表示在状态 s 下的期望回报,即 状态价值函数。
定义 优势函数(Advantage Function) Aθ(s,a)=Qθ(s,a)−Vθ(s),它衡量在状态 s 下,动作 a 相对于"该状态下的平均水平"有多大优势。
将上述策略梯度公式改写为:
N1n=1∑Nt=1∑TnAθ(stn,atn)∇logPθ(atn∣stn)
利用 Bellman 关系,可以把动作价值函数与状态价值函数联系起来:
Qθ(st,a)=rt+γVθ(st+1)Aθ(st,a)=rt+γVθ(st+1)−Vθ(st)Vθ(st+1)≈rt+1+γVθ(st+2)经过这一步代换,Aθ(s,a) 的估计只依赖一个价值函数 Vθ,公式整体的复杂度也得以降低。
那么应该向后展开(rollout)多少步再用 Vθ 截断呢?展开越多步,估计就越接近真实回报,偏差越小,但方差也会越大;反之展开越少,偏差大、方差小。这就是经典的 bias-variance trade-off:
Aθ1(st,a)=rt+γVθ(st+1)−Vθ(st)Aθ2(st,a)=rt+γrt+1+γ2Vθ(st+2)−Vθ(st)Aθ3(st,a)=rt+γ∗rt+1+γ2∗rt+2+γ3Vθ(st+3)−Vθ(st)⋮AθT(st,a)=rt+γ∗rt+1+γ2∗rt+2+γ3∗rt+3+⋯+γT∗rT−Vθ(st)为了让公式更简洁,引入一个中间量 δtV,表示 第 t 步采取该动作所带来的 TD 残差(temporal-difference error):
δtV=rt+γ∗Vθ(st+1)−Vθ(st)δt+1V=rt+1+γ∗Vθ(st+2)−Vθ(st+1)Aθ1(st,a)=δtVAθ2(st,a)=δtV+γδt+1VAθ3(st,a)=δtV+γδt+1V+γ2δt+2V⋮那么"展开几步"这件事到底该怎么选?GAE 给出的答案是:小孩子才做选择,全都要!
GAE 把不同展开步数的优势估计 Ak 用一个权重 λ 做指数加权平均,再化简成一个简洁的求和形式:
AθGAE(st,a)=(1−λ)(Aθ1+λ∗Aθ2+λ2Aθ3+⋯)λ=0.9:AθGAE=0.1Aθ1+0.09Aθ2+0.081Aθ3+⋯=(1−λ)(δtV+λ∗(δtV+γδt+1V)+λ2(δtV+γδt+1V+γ2δt+2V)+⋯)=(1−λ)(δtV(1+λ+λ2+⋯)+γδt+1V∗(λ+λ2+⋯)+⋯)=(1−λ)(δtV1−λ1+γδt+1V1−λλ+⋯)=b=0∑∞(γλ)bδt+bV通过等比数列求和化简,最终得到 GAE 的紧凑形式。
GAE 优势函数本质上是在 λ→0(高偏差、低方差)与 λ→1(低偏差、高方差)之间做插值,从而平衡 bias 与 variance。
整理一下,到这里我们得到了三个关键表达式:
δtV=rt+γVθ(st+1)−Vθ(st)AθGAE(st,a)=b=0∑∞(γλ)bδt+bVN1n=1∑Nt=1∑TnAθGAE(snt,ant)∇logPθ(ant∣snt)这里的状态价值函数 Vθ 一般用一个神经网络来拟合(即 critic 网络),可以与策略网络共用主干参数,仅在最后一层分叉为两个 head。
在经典的强化学习训练范式里,我们通常一边采集数据、一边更新模型,采过的数据用一次就丢掉——这种做法被称为 on-policy。问题在于,强化学习的环境交互成本往往很高,这样"用一次就扔"显然非常浪费。如果我们能让当前策略 πθ 复用 旧策略 πθ′ 采集的数据进行训练(即 off-policy),训练效率就能显著提升。要实现这一点,关键就是 重要性采样(Importance Sampling)。
我们可以把"f(x) 在分布 p 下的期望"改写为"f(x)⋅q(x)p(x) 在另一分布 q(proposal distribution)下的期望",这样就能用从 q 采样的数据来估计原本在 p 下的期望:
E(f(x))x∼p(x)=x∑f(x)⋅p(x)=x∑f(x)⋅p(x)q(x)q(x)=x∑f(x)q(x)p(x)⋅q(x)=E(f(x)q(x)p(x))x∼q(x)≈N1n=1∑Nf(xn)q(xn)p(xn),xn∼q(x)利用重要性采样,我们就可以把 on-policy 的梯度公式改写为可以复用旧数据的 off-policy 形式。
记 θ′ 为 采集数据时使用的旧策略,θ 为 当前要优化的策略;优势 Aθ′GAE 由旧策略下的价值网络估计而来。结合恒等式 ∇logf(x)=f(x)∇f(x),可以把策略梯度写成包含 重要性采样比 Pθ′Pθ 的形式:
∇logf(x)=f(x)∇f(x)N1n=1∑Nt=1∑TnAθGAE(snt,ant)∇logPθ(ant∣snt)=N1n=1∑Nt=1∑TnAθ′GAE(snt,ant)Pθ′(ant∣snt)Pθ(ant∣snt)∇logPθ(ant∣snt)=N1n=1∑Nt=1∑TnAθ′GAE(snt,ant)Pθ′(ant∣snt)Pθ(ant∣snt)Pθ(ant∣snt)∇Pθ(ant∣snt)=N1n=1∑Nt=1∑TnAθ′GAE(snt,ant)Pθ′(ant∣snt)∇Pθ(ant∣snt)对应地,将期望最大化转换为损失最小化(取负号):
Loss=−N1n=1∑Nt=1∑TnAθ′GAE(snt,ant)Pθ′(ant∣snt)Pθ(ant∣snt)这里有一个 隐含的前提:重要性采样要求 θ 与 θ′ 不能差太多,否则比值 Pθ′Pθ 的方差会爆炸,估计就会失真。换句话说,我们需要给"新旧策略的差距"施加一个约束。PPO 给出了两种约束方式:
如何让训练策略与参考策略不至于偏离太远?最直观的做法是给目标函数加一个 KL 散度惩罚项。KL 散度衡量两个分布的差异:差异越小,KL 越接近 0;差异越大,KL 越大。我们用一个权重 β 来控制这一惩罚的强度(实际中 β 还会做自适应调整):
Lossppo=−N1n=1∑Nt=1∑TnAθ′GAE(snt,ant)Pθ′(ant∣snt)Pθ(ant∣snt)+βKL(Pθ,Pθ′)PPO 还有一种更常用的实现,用 截断(clip)重要性采样比 来代替 KL 惩罚,同样起到限制新旧策略偏差的作用。它的目标函数由两部分组成(下方红色与蓝色项),最终取两者的 较小值:
- 红色部分:原始的 surrogate 目标 ρ⋅A,其中 ρ=Pθ′(a∣s)Pθ(a∣s) 是重要性采样比;
- 蓝色部分:把 ρ 截断到 [1−ϵ,1+ϵ] 区间内之后再乘以 A;当 ρ 落在区间内时返回原值,落在外面则返回最近的边界值。
对二者取 min 的目的是:当一个动作有正向优势 A>0 时,限制 ρ 不会被推得过高(避免步子迈太大);当 A<0 时,限制 ρ 不会被压得过低。这样既保证了"敢于改进策略",又避免了"一次更新偏离太多"。
Lossppo2=−N1n=1∑Nt=1∑Tnmin(Aθ′GAE(snt,ant)Pθ′(ant∣snt)Pθ(ant∣snt),clip(Pθ′(ant∣snt)Pθ(ant∣snt),1−ϵ,1+ϵ)Aθ′GAE(snt,ant))PPO-Clip 由于实现简单、效果稳定,是当前应用最广泛的版本,也是 RLHF/GRPO 等大模型对齐算法的基础。
零基础学习强化学习算法:ppo