Training a CartPole Agent with Proximal Policy Optimization (PPO)
1. Introduction
Reinforcement Learning (RL) has become one of the most exciting areas of machine learning, with applications ranging from robotics to games like Go and StarCraft. A classic benchmark for testing RL algorithms is the CartPole environment [3], where an agent must balance a pole on a moving cart by applying forces left or right.
In this blog, we'll walk through training a CartPole agent using Proximal Policy Optimization (PPO) [1] — one of the most popular and effective policy gradient algorithms. We'll cover the mathematical background, explain why PPO works, and then dive into the implementation details from the code.
2. Reinforcement Learning Basics
At its core, RL is modeled as a Markov Decision Process (MDP), defined by:
- States \(s_t\): the environment’s representation (cart position, velocity, pole angle, etc.).
- Actions \(a_t\): discrete choices (push cart left or right).
- Rewards \(r_t\): scalar feedback (here +1 for every timestep the pole stays upright).
The agent follows a policy \(\pi(a|s)\), mapping states to action probabilities. In policy gradient methods, we directly optimize this policy by maximizing the expected cumulative reward:
\[ J(\theta) = \mathbb{E}_{\pi_\theta} \Bigg[ \sum_t \gamma^t r_t \Bigg] \]
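The gradient of this objective is typically estimated with the log-derivative (score function) trick, which is the common starting point for all policy gradient methods, including PPO:
\[ \nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\Big[ \nabla_\theta \log \pi_\theta(a_t|s_t)\, \hat{A}_t \Big] \]
where \(\hat{A}_t\) is an advantage estimate that replaces the raw return to reduce the variance of the estimator.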
3. Background: TRPO
Vanilla policy gradient suffers from high variance and unstable updates. Trust Region Policy Optimization (TRPO) [2] addressed this by ensuring each update stays close to the old policy using a KL-divergence constraint:
\[ \max_\theta \; \mathbb{E}\left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} \hat{A}_t \right] \quad \text{s.t.} \quad D_{KL}(\pi_{\theta_{old}} \;||\; \pi_\theta) \leq \delta \]
TRPO guarantees monotonic improvement, but requires expensive second-order optimization. This makes it impractical for large-scale problems.
4. PPO: The Key Idea
PPO simplifies TRPO by replacing the hard KL constraint with a clipped surrogate objective. Instead of enforcing a strict trust region, it limits how much the probability ratio can change:
\[ L^{CLIP}(\theta) = \mathbb{E}\Big[ \min\big(r_t(\theta)\hat{A}_t,\;\; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t \big) \Big] \]
where \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\) is the importance sampling ratio.
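To see what the clipping does in practice, here is a tiny standalone example (the numbers are illustrative and my own, not from the post's code) that evaluates the clipped surrogate for three probability ratios with \(\epsilon = 0.2\) and positive advantages:

import torch

epsilon = 0.2
ratio = torch.tensor([0.7, 1.0, 1.3])      # pi_new / pi_old for three sampled actions
advantage = torch.tensor([1.0, 1.0, 1.0])  # positive advantage estimates

surr1 = ratio * advantage
surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
loss = -torch.min(surr1, surr2)            # negated because we minimize a loss

print(loss)  # tensor([-0.7000, -1.0000, -1.2000]): the 1.3 ratio is capped at 1.2

The ratio of 1.3 contributes as if it were 1.2, so the gradient gives the policy no incentive to push that action's probability any further. This is the "soft trust region" that replaces TRPO's KL constraint.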
PPO also adds:
- Value loss for better baseline estimation.
- Entropy bonus to encourage exploration.
This combination makes PPO both stable and sample-efficient, which is why it has become the de facto standard in RL research and applications.
5. Math Foundations
Let’s summarize the full PPO objective used in practice:
\[ L(\theta) = \mathbb{E}\Big[ L^{CLIP}(\theta) - c_1 (V_\theta(s_t) - R_t)^2 + c_2 H[\pi_\theta](s_t) \Big] \]
- \(L^{CLIP}\): clipped surrogate policy loss.
- \(V_\theta(s_t)\): value function prediction.
- \(R_t\): discounted return.
- \(H[\pi_\theta]\): policy entropy (encourages exploration).
For advantage estimation, one can use Generalized Advantage Estimation (GAE), which trades off bias and variance. This implementation uses a simpler reward-to-go approach.
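For reference, a minimal GAE sketch (not part of this implementation; the function name and arguments are mine) would compute an exponentially weighted sum of TD errors:

def compute_gae(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation: discounted sum of TD errors."""
    advantages = []
    gae = 0.0
    values = values + [next_value]
    for t in reversed(range(len(rewards))):
        mask = 1.0 - float(dones[t])  # zero out the bootstrap at episode ends
        delta = rewards[t] + gamma * values[t + 1] * mask - values[t]
        gae = delta + gamma * lam * mask * gae
        advantages.insert(0, gae)
    returns = [adv + v for adv, v in zip(advantages, values[:-1])]
    return advantages, returns

Setting lam = 1 recovers the reward-to-go advantage used in this post, while smaller values trade variance for bias.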
6. CartPole Setup
The CartPole environment (CartPole-v1) has:
- Observations: [cart position, velocity, pole angle, angular velocity] (4 floats).
- Actions: {0: push left, 1: push right}.
- Reward: +1 for every timestep the pole remains balanced.
- Termination: pole falls too far or cart moves out of bounds.
It’s a simple yet effective testbed for experimenting with RL algorithms.
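If you haven't used the environment before, the basic Gymnasium interaction loop looks like this (a random policy, no learning):

import gymnasium as gym

env = gym.make('CartPole-v1')
state, _ = env.reset(seed=0)
total_reward = 0
done = False
while not done:
    action = env.action_space.sample()  # random action: 0 (left) or 1 (right)
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated
env.close()
print(f"Random policy reward: {total_reward}")  # a random policy usually scores only around 20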
7. Implementation Walkthrough
Policy & Value Networks
We use a shared neural network with two heads (an actor-critic network):
- Policy head: the actor. Outputs action probabilities via a softmax.
- Value head: the critic. Outputs a scalar estimate of the state value, which serves as the baseline for the actor's advantage estimates.
# Imports used throughout the snippets in this post
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym  # Gymnasium (reference [3]); its reset/step API matches the training loop below
from tqdm import tqdm


class ActorCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(ActorCriticNetwork, self).__init__()
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh()
        )
        # Actor head (policy): action probabilities
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic head (value function): scalar state value
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        features = self.shared(state)
        action_probs = self.actor(features)
        state_value = self.critic(features)
        return action_probs, state_value
Rollout Collection & Batching
The agent collects a fixed number of environment steps (update_frequency in the code, 2048 steps) before each update, storing states, actions, rewards, value estimates, log-probabilities, and done flags. After a rollout, we compute discounted returns:
def compute_returns(self, next_value):
    """Discounted reward-to-go, bootstrapped from the value of the last state."""
    returns = []
    R = next_value
    for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
        if done:
            R = 0
        R = reward + self.gamma * R
        returns.insert(0, R)
    return returns
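Two small methods used by the training loop, select_action and store_transition, aren't shown in the post. A plausible sketch, consistent with how they are called below and with the buffers that update() reads (and reusing the torch and Categorical imports from above), would be:

def select_action(self, state):
    """Sample an action from the current policy; return the action, its log-prob, and the value estimate."""
    state_t = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
        action_probs, state_value = self.policy(state_t)
    dist = Categorical(action_probs)
    action = dist.sample()
    return action.item(), dist.log_prob(action).item(), state_value.item()

def store_transition(self, state, action, reward, log_prob, value, done):
    """Append one step of experience to the rollout buffers."""
    self.states.append(state)
    self.actions.append(action)
    self.rewards.append(reward)
    self.log_probs.append(log_prob)
    self.values.append(value)
    self.dones.append(done)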
Loss Functions
- Policy loss: PPO clipped surrogate objective.
- Value loss: MSE between predicted and actual returns.
- Entropy loss: encourages exploration.
PPO Update Rule
def update(self):
    """Update policy using PPO"""
    # Compute last state value for bootstrapping
    last_state = torch.FloatTensor(self.states[-1]).unsqueeze(0)
    with torch.no_grad():
        _, last_value = self.policy(last_state)
    last_value = last_value.item() if not self.dones[-1] else 0.0

    # Compute returns and advantages
    returns = self.compute_returns(last_value)
    returns = torch.FloatTensor(returns)

    # Convert buffers to tensors
    old_states = torch.FloatTensor(np.array(self.states))
    old_actions = torch.LongTensor(self.actions)
    old_log_probs = torch.FloatTensor(self.log_probs)
    old_values = torch.FloatTensor(self.values)

    # Normalize advantages
    advantages = returns - old_values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # PPO update for multiple epochs
    for epoch in range(self.epochs):
        indices = np.arange(len(self.states))
        np.random.shuffle(indices)

        for start in range(0, len(self.states), self.batch_size):
            end = start + self.batch_size
            batch_indices = indices[start:end]

            batch_states = old_states[batch_indices]
            batch_actions = old_actions[batch_indices]
            batch_old_log_probs = old_log_probs[batch_indices]
            batch_advantages = advantages[batch_indices]
            batch_returns = returns[batch_indices]

            # Get current policy predictions
            action_probs, state_values = self.policy(batch_states)
            dist = Categorical(action_probs)
            new_log_probs = dist.log_prob(batch_actions)
            entropy = dist.entropy().mean()

            # Importance sampling ratio
            ratio = torch.exp(new_log_probs - batch_old_log_probs)

            # Clipped surrogate loss
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) * batch_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # Value loss
            value_loss = nn.MSELoss()(state_values.squeeze(), batch_returns)

            # Total loss
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

            # Optimize
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
            self.optimizer.step()

    # Clear rollout buffers so the next update uses fresh on-policy data
    self.states, self.actions, self.rewards = [], [], []
    self.log_probs, self.values, self.dones = [], [], []
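Note that the coefficients 0.5 and 0.01 in the total loss play the roles of \(c_1\) and \(c_2\) from the objective in Section 5 (the signs flip because the code minimizes a loss rather than maximizing the objective), and the gradient-norm clipping at 0.5 adds a further layer of stabilization on top of the ratio clipping itself.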
Training Loop
We iterate over episodes, collect rollouts, split into minibatches, and optimize with Adam. Hyperparameters include:
- \(\gamma = 0.99\) (discount factor)
- \(\epsilon = 0.2\) (clip parameter)
- Learning rate = 3e-4
- Rollout length = 2048
- Batch epochs = 10
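The PPOAgent class itself isn't reproduced in full here. A minimal constructor sketch that wires these hyperparameters to the attributes the update() method reads (the batch size of 64 and any attribute names beyond those visible above are my assumptions) might look like:

class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 epsilon=0.2, epochs=10, batch_size=64):
        self.policy = ActorCriticNetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma            # discount factor
        self.epsilon = epsilon        # PPO clip parameter
        self.epochs = epochs          # optimization epochs per rollout
        self.batch_size = batch_size  # minibatch size (assumed; not stated in the post)
        # Rollout buffers read by update()
        self.states, self.actions, self.rewards = [], [], []
        self.log_probs, self.values, self.dones = [], [], []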
def train_ppo(episodes=1000, max_steps=500, update_frequency=2048):
    """Train PPO agent on CartPole environment"""
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = PPOAgent(state_dim, action_dim)

    episode_rewards = []
    running_reward = 0
    step_count = 0

    for episode in tqdm(range(episodes), desc="Training PPO"):
        state, _ = env.reset()
        episode_reward = 0
        done = False

        while not done:
            action, log_prob, value = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_transition(state, action, reward, log_prob, value, done)

            state = next_state
            episode_reward += reward
            step_count += 1

            # Update policy every update_frequency steps
            if step_count % update_frequency == 0:
                agent.update()

        episode_rewards.append(episode_reward)
        running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward

        # CartPole-v1 counts as solved at an average reward of 475 over 100 episodes
        if len(episode_rewards) >= 100 and np.mean(episode_rewards[-100:]) >= 475:
            print(f"Solved at episode {episode}! Average reward: {np.mean(episode_rewards[-100:]):.2f}")
            break

    env.close()
    return agent, episode_rewards
8. Results
After training, the PPO agent achieves the maximum score of 500 in all 5 test runs, meaning it balances the pole for the full 500 steps of every episode.
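The evaluation code isn't included in the post; a simple greedy test rollout along these lines (my own sketch, taking the argmax action rather than sampling, and reusing the imports above) is one way such test runs could be performed:

def evaluate(agent, n_episodes=5):
    """Run greedy rollouts with the trained policy and return the scores."""
    env = gym.make('CartPole-v1')
    scores = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        total, done = 0, False
        while not done:
            state_t = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                action_probs, _ = agent.policy(state_t)
            action = action_probs.argmax(dim=-1).item()  # greedy action
            state, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
        scores.append(total)
    env.close()
    return scores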
9. References
- [1] Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). Proximal Policy Optimization Algorithms. arXiv preprint arXiv:1707.06347. https://arxiv.org/abs/1707.06347
- [2] Schulman, J., Levine, S., Abbeel, P., Jordan, M., & Moritz, P. (2015). Trust Region Policy Optimization. In International Conference on Machine Learning (pp. 1889-1897). PMLR. https://arxiv.org/abs/1502.05477
- [3] Towers, M., Terry, J. K., Kwiatkowski, A., Balis, J. U., de Cola, G., Deleu, T., Goulão, M., Kallinteris, A., KG, A., Krimmel, M., Perez-Vicente, R., Pierré, A., Schulhoff, S., Tai, J. J., Shen, A. T. J., & Younis, O. G. (2023). Gymnasium. Zenodo. https://zenodo.org/record/8127025
10. Conclusion
PPO combines the stability of TRPO with the simplicity of vanilla policy gradient methods, making it a powerful algorithm for both discrete and continuous control problems. Even on a small benchmark like CartPole, we can see how PPO delivers stable and reliable training.