1. Overview of PPO Implementation
Simplified, the RLHF-with-PPO algorithm in TRL proceeds as follows: generate responses with the current policy, score them with the reward model, estimate per-token advantages with GAE, and then run several PPO epochs of policy and value updates on the same batch of rollouts.
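To make this concrete, here is a minimal, self-contained PyTorch sketch of such a loop. It is illustrative rather than TRL's actual source: the toy `policy` and `value_head` linear layers, the synthetic `states` and `rewards`, and helper names such as `compute_gae` are assumptions chosen so the script runs stand-alone, whereas the real trainer operates on full language-model outputs and folds a KL penalty into the per-token rewards.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hyperparameters (illustrative values).
hidden, vocab, T = 32, 16, 8                 # hidden size, vocabulary size, response length
ppo_epochs, gamma, lam, clip_eps = 4, 1.0, 0.95, 0.2

# Toy stand-ins for the language model's policy head and the value head.
policy = nn.Linear(hidden, vocab)
value_head = nn.Linear(hidden, 1)
optimizer = torch.optim.Adam(
    list(policy.parameters()) + list(value_head.parameters()), lr=1e-3
)

def compute_gae(rewards, values, gamma, lam):
    """Generalized Advantage Estimation (Eq 3), computed by a backward pass."""
    advantages, gae, next_value = [], 0.0, 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]   # TD error delta_t
        gae = delta + gamma * lam * gae                       # A_t = delta_t + gamma*lam*A_{t+1}
        next_value = values[t]
        advantages.append(gae)
    return torch.tensor(advantages[::-1])

# ---- 1. Rollout: generate a response with the current (soon-to-be "old") policy ----
states = torch.randn(T, hidden)                    # stand-in for per-token hidden states
with torch.no_grad():
    dist = torch.distributions.Categorical(logits=policy(states))
    actions = dist.sample()                        # generated tokens
    old_logprobs = dist.log_prob(actions)          # log pi_old(a_t|s_t)
    old_values = value_head(states).squeeze(-1)    # V_old(s_t)
rewards = torch.randn(T)                           # stand-in for reward-model (and KL-shaped) rewards

# ---- 2. Advantage and return estimation ----
advantages = compute_gae(rewards.tolist(), old_values.tolist(), gamma, lam)
returns = advantages + old_values                  # critic regression target (Eq 8)

# ---- 3. PPO updates: reuse the same rollout for ppo_epochs ----
for _ in range(ppo_epochs):
    dist = torch.distributions.Categorical(logits=policy(states))
    logprobs = dist.log_prob(actions)
    ratio = torch.exp(logprobs - old_logprobs)     # pi_theta / pi_old (Eq 7)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages,
    ).mean()
    values = value_head(states).squeeze(-1)
    value_loss = F.mse_loss(values, returns)       # Eq 8
    loss = policy_loss + 0.5 * value_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```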
2. Theory
To express the PPO algorithm explicitly, I will ignore the KL-penalty term and the clip operation for simplicity.
2.1. Policy Loss
2.1.1. Policy Gradient
The optimization objective of reinforcement learning is to maximize the expected return:
$$\max_{\pi_\theta}\mathcal{J}(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]= \sum_\tau R(\tau)P(\tau|\pi_\theta) \tag{Eq 1}$$where $\tau$ is a trajectory sampled from the policy $\pi_\theta$, and $R(\tau)$ is the return of the trajectory.
Then the policy gradient can be derived as follows (details can be found in [1]):
$$ \begin{align*} \nabla_{\theta}\mathcal{J}(\pi_\theta) &= \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)\nabla log(P(\tau|\pi_\theta))] \\ & = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau) \sum_{t=0}^{T_n-1}\nabla log(\pi_\theta(a_t|s_t))] \tag{Eq 2} \end{align*} $$where $T_n$ is the length of the trajectory, and $s_t$ and $a_t$ are the state and the action at time step $t$.
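For completeness, the step from Eq 1 to Eq 2 relies on the log-derivative trick together with the factorization of the trajectory probability:
$$ \begin{align*} \nabla_\theta P(\tau|\pi_\theta) &= P(\tau|\pi_\theta)\nabla_\theta log(P(\tau|\pi_\theta)) \\ log(P(\tau|\pi_\theta)) &= log(\rho(s_0)) + \sum_{t=0}^{T_n-1}[log(\pi_\theta(a_t|s_t)) + log(P(s_{t+1}|s_t,a_t))] \end{align*} $$where $\rho(s_0)$ is the initial-state distribution and $P(s_{t+1}|s_t,a_t)$ is the environment dynamics. Neither depends on $\theta$, so only the $log(\pi_\theta(a_t|s_t))$ terms survive the gradient.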
2.1.2. Improvement of $R(\tau)$ - GAE
Looking at Eq 2, the return $R(\tau)$ is the sum of rewards over the whole trajectory, while $\pi_\theta(a_t|s_t)$ is a per-step probability distribution, so every step's log-probability is weighted by the same whole-trajectory return regardless of how much that step actually contributed. There are many ways to replace $R(\tau)$ with a better per-step signal; the most common one is Generalized Advantage Estimation (GAE):
$$ \begin{align*} &\delta_\phi(s_t,a_t) = r_t +\gamma V_\phi(s_{t+1}) - V_\phi(s_t) \\ &A_\phi(s_t,a_t)=\sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_\phi(s_{t+l},a_{t+l}) \tag{Eq 3} \end{align*} $$where $A_\phi(s_t,a_t)$ is the advantage function, $r_t$ is the reward at time step $t$, $\gamma$ is the discount factor, $\lambda \in [0,1]$ is the GAE parameter that trades off bias and variance, and $V_\phi(s)$ is the value function.
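In practice the infinite sum in Eq 3 is truncated at the end of the trajectory and computed with a single backward pass over the time steps, using the convention that the value and the advantage after the final step are zero:
$$ A_\phi(s_t,a_t) = \delta_\phi(s_t,a_t) + \gamma \lambda A_\phi(s_{t+1},a_{t+1}) $$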
The new policy gradient can be written as:
$$ \nabla_{\theta}\mathcal{J}(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[\sum_{t=0}^{T_n-1} A_\phi(s_t,a_t)\nabla log(\pi_\theta(a_t|s_t))] \tag{Eq 4} $$According to the gradient in Eq 4, we can obtain the new optimization objective:
$$ \begin{align*} \max_{\pi_\theta}\mathcal{J}(\pi_\theta) &= \mathbb{E}_{\tau \sim \pi_\theta}[\sum_{t=0}^{T_n-1} A_\phi(s_t,a_t) log(\pi_\theta(a_t|s_t))] \\ & \approx \frac{1}{N}\sum_{n=1}^{N}\sum_{t=0}^{T_n-1} A_\phi(s_t,a_t) log(\pi_\theta(a_t|s_t)) \\ \Leftrightarrow \max_{\pi_\theta}\mathcal{J}(\pi_\theta) & = \frac{1}{\sum_{n=1}^{N} T_n} \sum_{n=1}^{N}\sum_{t=0}^{T_n-1}A_\phi(s_t,a_t) log(\pi_\theta(a_t|s_t)) \\ &= \mathbb{E}_t[A_\phi(s_t,a_t) log(\pi_\theta(a_t|s_t))] \tag{Eq 5} \end{align*} $$Here $N$ is the number of sampled trajectories. The step from the second to the third line of Eq 5 only changes the normalization constant in front of the sum (from $1/N$ to $1/\sum_{n=1}^{N} T_n$), which does not affect the maximizer, so maximizing the expectation over trajectories $\tau$ is equivalent to maximizing the expectation over individual time steps $t$.
2.1.3. Importance Sampling
In the above equation, the policy gradient is computed from the advantage function and the log-probability of the action taken by the policy. In practice, however, we have an old policy $\pi_{\text{old}}(a|s)$ that was used to generate the trajectories $\tau$, and we reuse those trajectories to update the current policy $\pi_\theta(a|s)$ for ppo_epochs iterations.
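The tool that makes this reuse valid is importance sampling: an expectation under the current policy can be rewritten as an expectation under the old policy by reweighting each sample with the probability ratio of the two policies,
$$ \mathbb{E}_{a_t \sim \pi_\theta}[f(s_t,a_t)] = \mathbb{E}_{a_t \sim \pi_{\text{old}}}[\frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} f(s_t,a_t)] $$which holds whenever $\pi_{\text{old}}$ assigns non-zero probability to every action that $\pi_\theta$ can take.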
The policy gradient can then be rewritten as:
$$ \begin{align*} \nabla_{\theta}\mathcal{J}(\pi_\theta) &= \mathbb{E}_t[A_\phi(s_t,a_t)\nabla log(\pi_\theta(a_t|s_t))]\\ &=\mathop{\mathbb{E}_t}\limits_{\tau \sim \pi_{\text{old}}}[ \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} A_\phi(s_t,a_t)\nabla log(\pi_\theta(a_t|s_t)) ] \tag{Eq 6} \end{align*} $$Finally, we obtain the optimization objective used for the policy loss:
$$ \mathcal{J}(\pi_\theta) = \mathop{\mathbb{E}_t}\limits_{\tau \sim \pi_{\text{old}}}[A_\phi(s_t,a_t) \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}] \tag{Eq 7} $$
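As a minimal sketch, the per-token policy loss corresponding to Eq 7 (negated so it can be minimized with gradient descent, and without the clip term, per the simplification above) can be written as follows; the tensor names are assumptions, with one entry per generated token:

```python
import torch

def policy_loss(logprobs: torch.Tensor,      # log pi_theta(a_t|s_t), shape (T,)
                old_logprobs: torch.Tensor,  # log pi_old(a_t|s_t),   shape (T,)
                advantages: torch.Tensor,    # A_phi(s_t, a_t),       shape (T,)
                ) -> torch.Tensor:
    # Eq 7: maximize E_t[A_t * pi_theta/pi_old]  <=>  minimize the negative mean.
    ratio = torch.exp(logprobs - old_logprobs)
    return -(advantages * ratio).mean()
```

The full PPO objective additionally clips the ratio to $[1-\epsilon, 1+\epsilon]$ and takes the minimum of the clipped and unclipped terms, which is exactly the clip operation ignored in this section.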
2.2. Critic Loss
The critic $V_\phi$ is trained to regress the estimated return of each state, which is reconstructed from the advantage and the value $V_{\text{old}}(s_t)$ recorded when the trajectory was generated:
$$ \mathcal{L}_{\text{critic}} = \mathbb{E}_t[(V_\phi(s_t) - (A_\phi(s_t,a_t)+V_{\text{old}}(s_t)))^2] \tag{Eq 8} $$
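A matching sketch of Eq 8, again with assumed per-token tensors:

```python
import torch

def critic_loss(values: torch.Tensor,      # V_phi(s_t) from the current critic, shape (T,)
                old_values: torch.Tensor,  # V_old(s_t) recorded during the rollout, shape (T,)
                advantages: torch.Tensor,  # A_phi(s_t, a_t) from GAE, shape (T,)
                ) -> torch.Tensor:
    # Eq 8: regress V_phi(s_t) onto the return estimate A_t + V_old(s_t).
    returns = (advantages + old_values).detach()
    return ((values - returns) ** 2).mean()
```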
References
- [1] 猛猿. 人人都能看懂的RL-PPO理论知识 (RL-PPO theory that everyone can understand).
- [2] TRL: PPO Trainer Implementation.