Implementation:LaurentMazare Tch rs A2C Agent
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Reinforcement Learning, Game AI |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the Advantage Actor-Critic (A2C) reinforcement learning algorithm for playing Atari games, using a convolutional neural network with frame stacking.
Description
This module implements A2C, a synchronous variant of the Asynchronous Advantage Actor-Critic (A3C) algorithm introduced by DeepMind. It is designed for the SpaceInvadersNoFrameskip-v4 Atari environment with 16 parallel processes.
Model architecture: A convolutional neural network takes stacked frames (4 grayscale 84x84 images) as input:
- Conv2D(4 -> 32, kernel=8, stride=4) + ReLU
- Conv2D(32 -> 64, kernel=4, stride=2) + ReLU
- Conv2D(64 -> 64, kernel=3, stride=1) + ReLU + Flatten
- Linear(3136 -> 512) + ReLU
- Two heads: critic (Linear 512 -> 1) and actor (Linear 512 -> nact)
The model function returns a boxed closure, Box<dyn Fn(&Tensor) -> (Tensor, Tensor)>, that maps a batch of stacked frames to (value, action_logits).
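The 3136 input features of the first linear layer follow from the convolution sizes listed above. The sketch below reproduces that arithmetic in plain Rust, assuming unpadded, undilated convolutions; the helper names `conv_out` and `flattened_features` are hypothetical and do not appear in the example.

```rust
/// Output side length of a square convolution, assuming no padding and
/// no dilation: floor((input - kernel) / stride) + 1.
fn conv_out(input: usize, kernel: usize, stride: usize) -> usize {
    (input - kernel) / stride + 1
}

/// Number of features left after the three convolutions and the flatten,
/// for a square input image of side `side`.
fn flattened_features(side: usize) -> usize {
    let s1 = conv_out(side, 8, 4); // kernel 8, stride 4
    let s2 = conv_out(s1, 4, 2); // kernel 4, stride 2
    let s3 = conv_out(s2, 3, 1); // kernel 3, stride 1
    64 * s3 * s3 // 64 output channels times the spatial area
}
```

For an 84x84 input the side length shrinks 84 -> 20 -> 9 -> 7, so the flattened size is 64 * 7 * 7 = 3136.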
FrameStack: Maintains a rolling buffer of the last 4 frames per process. When an episode ends (indicated by the done mask), the frame stack is zeroed before adding the new observation.
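The rolling-buffer behaviour can be illustrated without tensors. The following is a minimal sketch, not the example's implementation: it stores frames as flat `Vec<f32>` buffers, the type `FrameStackSketch` and its `frame_len` parameter are hypothetical, and it assumes a mask value of 0.0 marks a process whose episode just ended.

```rust
/// Rolling buffer of the last `nstack` frames for each process.
/// Layout is [process][slot][pixel], with the oldest frame in slot 0.
struct FrameStackSketch {
    data: Vec<Vec<Vec<f32>>>,
}

impl FrameStackSketch {
    fn new(nprocs: usize, nstack: usize, frame_len: usize) -> Self {
        FrameStackSketch {
            data: vec![vec![vec![0.0; frame_len]; nstack]; nprocs],
        }
    }

    /// Pushes one new frame per process and returns the updated stacks.
    /// Assumption: `masks[p] == 0.0` means process `p` finished an episode,
    /// so its history is zeroed before the new frame is added.
    fn update(&mut self, frames: &[Vec<f32>], masks: Option<&[f32]>) -> &[Vec<Vec<f32>>] {
        for (p, stack) in self.data.iter_mut().enumerate() {
            if let Some(m) = masks {
                for slot in stack.iter_mut() {
                    for px in slot.iter_mut() {
                        *px *= m[p];
                    }
                }
            }
            // Shift every frame one slot towards the front; the oldest frame
            // wraps around to the last slot and is then overwritten.
            stack.rotate_left(1);
            let last = stack.len() - 1;
            stack[last] = frames[p].clone();
        }
        &self.data
    }
}
```

Multiplying by the mask mirrors the idea of zeroing the stack on episode end, so the first observation of a new episode is never mixed with frames from the previous one.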
Training loop: Runs for 1,000,000 updates, each collecting 5 steps from 16 parallel environments:
- Collect actions via multinomial sampling from the softmax policy.
- Compute N-step returns with discount factor 0.99 using the critic's bootstrap value.
- Calculate the value loss (mean of squared advantages), the action loss (policy gradient weighted by detached advantages), and the policy entropy (used as a bonus with weight 0.01).
- Combined loss: 0.5 * value_loss + action_loss - 0.01 * entropy.
- Update with the Adam optimizer (lr=1e-4) and gradient clipping at 0.5.
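The return and loss computations in the steps above can be sketched for a single environment in plain Rust. This is an illustration under stated assumptions rather than the example's tensor code: the function names are hypothetical, a mask of 0.0 is assumed to mark a step that ended an episode, and because there is no autograd here the "detached" advantage is only conceptual.

```rust
/// N-step discounted returns for one environment, computed backwards from
/// the critic's bootstrap value. `masks[s] == 0.0` is assumed to mean the
/// episode ended at step `s`, which cuts the return off at that step.
fn n_step_returns(rewards: &[f64], masks: &[f64], bootstrap: f64, gamma: f64) -> Vec<f64> {
    let mut returns = vec![0.0; rewards.len()];
    let mut next = bootstrap;
    for s in (0..rewards.len()).rev() {
        next = rewards[s] + gamma * masks[s] * next;
        returns[s] = next;
    }
    returns
}

/// Combined loss: 0.5 * value_loss + action_loss - 0.01 * entropy.
/// `log_probs` holds the log-probability of the action taken at each step,
/// and `entropy` is the mean policy entropy over the rollout.
fn a2c_loss(returns: &[f64], values: &[f64], log_probs: &[f64], entropy: f64) -> f64 {
    let n = returns.len() as f64;
    let mut value_loss = 0.0;
    let mut action_loss = 0.0;
    for i in 0..returns.len() {
        let advantage = returns[i] - values[i];
        value_loss += advantage * advantage;
        // In an autograd setting the advantage is treated as a constant here.
        action_loss += -advantage * log_probs[i];
    }
    0.5 * (value_loss / n) + action_loss / n - 0.01 * entropy
}
```

In the example the discount factor is 0.99; the computation is the same for any `gamma` between 0 and 1.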
Model checkpoints are saved every 10,000 updates. A separate sample function loads trained weights and renders trajectories.
Usage
Use this example to learn how to implement actor-critic reinforcement learning with tch-rs, including parallel environment interaction via vectorized Gym environments, frame stacking, and N-step return estimation.
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: examples/reinforcement-learning/a2c.rs
- Lines: 1-171
Signature
type Model = Box<dyn Fn(&Tensor) -> (Tensor, Tensor)>;
fn model(p: &nn::Path, nact: i64) -> Model
struct FrameStack {
data: Tensor,
nprocs: i64,
nstack: i64,
}
impl FrameStack {
fn new(nprocs: i64, nstack: i64) -> FrameStack
fn update<'a>(&'a mut self, img: &Tensor, masks: Option<&Tensor>) -> &'a Tensor
}
pub fn train() -> cpython::PyResult<()>
pub fn sample<T: AsRef<std::path::Path>>(weight_file: T) -> cpython::PyResult<()>
Import
// Module within the reinforcement-learning example.
use super::vec_gym_env::VecGymEnv;
use tch::kind::{FLOAT_CPU, INT64_CPU};
use tch::{nn, nn::OptimizerConfig, Kind::Float, Tensor};
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ENV_NAME | Constant (String) | Yes | OpenAI Gym environment name ("SpaceInvadersNoFrameskip-v4"). |
| NPROCS | Constant (i64) | Yes | Number of parallel environment processes (16). |
| NSTEPS | Constant (i64) | Yes | Number of steps per update (5). |
| NSTACK | Constant (i64) | Yes | Number of frames to stack (4). |
| weight_file | Path (for sample) | For sampling | Path to saved model weights (.ot file). |
Outputs
| Name | Type | Description |
|---|---|---|
| stdout | Text | Training progress: update index, total episodes, and average reward per episode (every 500 updates). |
| a2cN.ot | File (weights) | Model checkpoints saved every 10,000 updates. |
| /dev/shm frames | Image files | Observation frames saved during sampling. |
Usage Examples
use super::vec_gym_env::VecGymEnv;
use tch::kind::{FLOAT_CPU, INT64_CPU};
use tch::{nn, nn::OptimizerConfig, Kind::Float, Tensor};
const ENV_NAME: &str = "SpaceInvadersNoFrameskip-v4";
const NPROCS: i64 = 16;
const NSTEPS: i64 = 5;
const NSTACK: i64 = 4;
// `env` is the VecGymEnv created at the start of train(); its construction is omitted here.
// Build the actor-critic model
let vs = nn::VarStore::new(tch::Device::cuda_if_available());
let model = model(&vs.root(), env.action_space());
let mut opt = nn::Adam::default().build(&vs, 1e-4).unwrap();
// Collect rollout data
let mut frame_stack = FrameStack::new(NPROCS, NSTACK);
let _ = frame_stack.update(&env.reset()?, None);
// `s_states` is the rollout buffer of stacked frames; its allocation is omitted here.
for s in 0..NSTEPS {
let (critic, actor) = tch::no_grad(|| model(&s_states.get(s)));
let probs = actor.softmax(-1, Float);
let actions = probs.multinomial(1, true).squeeze_dim(-1);
let step = env.step(Vec::<i64>::try_from(&actions).unwrap())?;
// ... store rewards, masks, observations ...
}
// Compute loss and update (value_loss, action_loss, and dist_entropy are derived from the rollout; omitted here)
let loss = value_loss * 0.5 + action_loss - dist_entropy * 0.01;
opt.backward_step_clip(&loss, 0.5);
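The sampling step in the loop above (softmax over the actor logits, then a multinomial draw) and the entropy term in the loss can be illustrated for a single logit vector in plain Rust. This is a sketch, not the tch-rs code path: the function names are hypothetical, and the uniform random number `u` in [0, 1) is passed in by the caller so the block stays free of external crates.

```rust
/// Numerically stable softmax: subtract the maximum logit before exponentiating.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Draws one action index from `probs` by inverse-CDF sampling.
/// `u` is assumed to be uniform in [0, 1) and supplied by the caller.
fn sample_action(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1 // guards against rounding when `u` is very close to 1
}

/// Shannon entropy of a probability vector, skipping zero entries.
fn entropy(probs: &[f64]) -> f64 {
    -probs.iter().filter(|p| **p > 0.0).map(|p| p * p.ln()).sum::<f64>()
}
```

A higher entropy means a more uniform policy, which is why subtracting the weighted entropy from the loss encourages exploration.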