Implementation:LaurentMazare Tch rs Policy Gradient
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Reinforcement Learning, Policy Optimization |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the REINFORCE policy gradient algorithm for the CartPole-v0 environment, using reward-to-go returns and a simple two-layer neural network policy.
Description
This module implements the REINFORCE policy gradient algorithm, adapted from the OpenAI Spinning Up series. It trains an agent to solve the CartPole-v0 balancing task.
Policy network: A simple feedforward neural network with:
- Linear(input_dim -> 32) + Tanh activation
- Linear(32 -> nact)
The input dimension is derived from the environment's observation space. Actions are sampled by applying softmax to the logits and using multinomial sampling inside a no_grad block.
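The sampling step described above can be sketched without any tensor library. This is a minimal illustration, not the code in the example: `softmax` and `sample_categorical` are hypothetical helpers, and the uniform draw `u` is passed in by the caller so that the sketch needs no random number crate. In the example itself, the tch calls `softmax` and `multinomial` do this work on tensors.

```rust
/// Numerically stable softmax over a non-empty slice of logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    // Subtracting the maximum keeps exp() from overflowing.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Inverse-CDF sampling from a categorical distribution.
/// `u` is assumed to be a uniform draw in [0, 1) supplied by the caller,
/// and `probs` is assumed to be non-empty.
fn sample_categorical(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    // Guard against rounding error when `u` is very close to 1.
    probs.len() - 1
}
```

Sampling happens outside gradient tracking because the sampled action is treated as data; gradients flow only through the log probabilities recomputed during the update phase.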
Reward-to-go computation: The accumulate_rewards function computes cumulative rewards by iterating backwards through the steps. When an episode boundary (is_done) is encountered, the accumulator resets to zero, so each step is credited only with the rewards collected from that step to the end of its own episode. Weighting each action by its reward-to-go rather than by the total episode reward reduces the variance of the gradient estimate.
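The backward pass described above can be sketched as follows. This is an illustration under stated assumptions, not the body of accumulate_rewards: `Transition` is a hypothetical stand-in for the two fields of `Step` that matter here, and `reward_to_go` is a hypothetical name.

```rust
/// Hypothetical stand-in for the reward and done flag carried by each step.
struct Transition {
    reward: f64,
    is_done: bool,
}

/// For every step, sum the rewards from that step to the end of its episode.
fn reward_to_go(steps: &[Transition]) -> Vec<f64> {
    let mut out = vec![0.0; steps.len()];
    let mut acc = 0.0;
    // Walk backwards so the running sum always holds the "future" rewards.
    for (i, step) in steps.iter().enumerate().rev() {
        if step.is_done {
            // The last step of an episode: rewards from the following
            // episode must not leak into this one.
            acc = 0.0;
        }
        acc += step.reward;
        out[i] = acc;
    }
    out
}
```

For two back-to-back episodes of lengths 3 and 2 with a reward of 1 per step, the result is 3, 2, 1 for the first episode and 2, 1 for the second.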
Training loop: For each of 50 epochs:
- Rollout phase: Collect more than 5000 steps of experience by running the current policy in the environment. The rollout ends at the first done signal received after the step count exceeds 5000, so every collected episode is complete.
- Update phase: Compute the policy gradient loss as the negative mean of (reward-to-go * log probability of the taken action). A one-hot action mask, created with scatter_value, selects the log probability of the chosen action from the log-softmax of the logits.
- The Adam optimizer (learning rate 0.01) performs a single backward step on this loss.
- Average reward per episode is printed for monitoring.
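The loss from the update phase can be written out on plain vectors to make the masking explicit. This is a minimal sketch, assuming a non-empty batch; `log_softmax` and `policy_gradient_loss` are hypothetical helpers, and the example itself performs the same arithmetic on tensors so that gradients can be taken.

```rust
/// Numerically stable log-softmax over a non-empty slice of logits.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = logits.iter().map(|&l| (l - max).exp()).sum::<f64>().ln();
    logits.iter().map(|&l| l - max - log_sum).collect()
}

/// Negative mean of (reward-to-go * log probability of the taken action).
/// `logits[i]` holds the policy outputs for step i, `actions[i]` the action
/// taken, and `returns[i]` the reward-to-go for that step.
fn policy_gradient_loss(logits: &[Vec<f64>], actions: &[usize], returns: &[f64]) -> f64 {
    let n = logits.len() as f64;
    let total: f64 = logits
        .iter()
        .zip(actions)
        .zip(returns)
        .map(|((row, &action), &ret)| {
            let log_probs = log_softmax(row);
            // One-hot mask: keep the chosen action's log probability,
            // zero out the rest, then sum across actions.
            let chosen: f64 = log_probs
                .iter()
                .enumerate()
                .map(|(j, lp)| if j == action { *lp } else { 0.0 })
                .sum();
            ret * chosen
        })
        .sum();
    -total / n
}
```

Minimizing this quantity raises the log probability of actions that were followed by high reward-to-go, which is the REINFORCE update expressed as a loss.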
Usage
Use this example to learn the fundamentals of policy gradient reinforcement learning with tch-rs, including rollout collection, reward-to-go computation, and the REINFORCE algorithm update rule.
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: examples/reinforcement-learning/policy_gradient.rs
- Lines: 1-84
Signature
fn model(p: &nn::Path, input_shape: &[i64], nact: i64) -> impl nn::Module
fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64>
pub fn run() -> cpython::PyResult<()>
Import
// Module within the reinforcement-learning example.
use super::gym_env::{GymEnv, Step};
use tch::{nn, nn::OptimizerConfig, Kind::Float, Tensor};
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (environment) | GymEnv | Yes | CartPole-v0 environment, created internally via GymEnv::new("CartPole-v0"). |
Outputs
| Name | Type | Description |
|---|---|---|
| stdout | Text | Per-epoch statistics: epoch index, number of episodes, and average reward per episode. |
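The per-epoch statistics can be derived from the collected steps alone. The sketch below is one plausible way to compute them, not a copy of the example's code: `StepRecord` and `epoch_stats` are hypothetical names, and the function assumes the rollout ended on a done step so that at least one episode is present.

```rust
/// Hypothetical stand-in for the reward and done flag carried by each step.
struct StepRecord {
    reward: f64,
    is_done: bool,
}

/// Returns (number of episodes, average reward per episode).
/// Assumes `steps` contains at least one completed episode.
fn epoch_stats(steps: &[StepRecord]) -> (usize, f64) {
    // Each done flag marks the end of exactly one episode.
    let episodes = steps.iter().filter(|s| s.is_done).count();
    let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
    (episodes, total_reward / episodes as f64)
}
```

Because CartPole gives a reward of 1 per step, the average reward per episode is also the average episode length, which makes it a direct measure of how long the pole stays balanced.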
Usage Examples
use super::gym_env::{GymEnv, Step};
use tch::{nn, nn::OptimizerConfig, Kind::Float, Tensor};
// Create environment and policy model
let env = GymEnv::new("CartPole-v0")?;
let vs = nn::VarStore::new(tch::Device::Cpu);
let model = model(&vs.root(), env.observation_space(), env.action_space());
let mut opt = nn::Adam::default().build(&vs, 1e-2).unwrap();
// Rollout: collect experience
let mut obs = env.reset()?;
let mut steps: Vec<Step<i64>> = vec![];
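loop {
    // Sample an action from the current policy without tracking gradients.
    let action = tch::no_grad(|| {
        obs.unsqueeze(0).apply(&model).softmax(1, Float).multinomial(1, true)
    });
    let action = i64::try_from(action).unwrap();
    let step = env.step(action)?;
    // Store the observation the action was chosen from, not the next one.
    steps.push(step.copy_with_obs(&obs));
    obs = if step.is_done { env.reset()? } else { step.obs };
    // Stop only at an episode boundary once enough steps are collected.
    if step.is_done && steps.len() > 5000 {
        break;
    }
}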
// Compute reward-to-go and policy gradient loss
let rewards = accumulate_rewards(&steps);
let rewards = Tensor::from_slice(&rewards).to_kind(Float);
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
let actions = Tensor::from_slice(&actions).unsqueeze(1);
let batch_size = steps.len() as i64;
// One-hot mask over the two CartPole actions.
let action_mask =
    Tensor::zeros([batch_size, 2], tch::kind::FLOAT_CPU).scatter_value(1, &actions, 1.0);
let obs: Vec<Tensor> = steps.into_iter().map(|s| s.obs).collect();
let logits = Tensor::stack(&obs, 0).apply(&model);
// Keep only the log probability of the action actually taken at each step.
let log_probs =
    (action_mask * logits.log_softmax(1, Float)).sum_dim_intlist(1, false, Float);
let loss = -(rewards * log_probs).mean(Float);
opt.backward_step(&loss);