Implementation:LaurentMazare Tch rs A2C Agent

From Leeroopedia


Knowledge Sources
Domains: Deep Learning, Reinforcement Learning, Game AI
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements the Advantage Actor-Critic (A2C) reinforcement learning algorithm for playing Atari games, using a convolutional neural network with frame stacking.

Description

This module implements A2C, a synchronous variant of the Asynchronous Advantage Actor-Critic (A3C) algorithm introduced by DeepMind. It targets the SpaceInvadersNoFrameskip-v4 Atari environment and steps 16 environment processes in parallel.

Model architecture: A convolutional neural network takes stacked frames (4 grayscale 84x84 images) as input:

  • Conv2D(4 -> 32, kernel=8, stride=4) + ReLU
  • Conv2D(32 -> 64, kernel=4, stride=2) + ReLU
  • Conv2D(64 -> 64, kernel=3, stride=1) + ReLU + Flatten
  • Linear(3136 -> 512) + ReLU
  • Two heads: critic (Linear 512 -> 1) and actor (Linear 512 -> nact)
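The 3136 input features of the first linear layer follow from the convolution sizes listed above. A minimal sketch of that arithmetic, assuming no padding and a dilation of 1 (the helper names `conv_out` and `flattened_features` are illustrative and not part of the module):

```rust
/// Output side length of a square convolution, assuming no padding and dilation 1.
fn conv_out(input: usize, kernel: usize, stride: usize) -> usize {
    (input - kernel) / stride + 1
}

/// Number of features left after the three convolutions and the flatten step.
fn flattened_features(side: usize) -> usize {
    let s1 = conv_out(side, 8, 4); // kernel=8, stride=4
    let s2 = conv_out(s1, 4, 2); // kernel=4, stride=2
    let s3 = conv_out(s2, 3, 1); // kernel=3, stride=1
    64 * s3 * s3 // 64 output channels in the last convolution
}

fn main() {
    // 84 -> 20 -> 9 -> 7, so 64 * 7 * 7 = 3136
    println!("{}", flattened_features(84));
}
```

If the input resolution changes, the first argument of the Linear(3136 -> 512) layer has to change with it.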

The model function returns a boxed closure of type Box<dyn Fn(&Tensor) -> (Tensor, Tensor)> that maps a batch of stacked frames to (value, action_logits).
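The boxed-closure pattern can be illustrated without tensors. The sketch below is a toy stand-in, not the module's network: it uses slices in place of `Tensor`, and `toy_model` and its weights are hypothetical. It shows only how parameters are moved into a closure that returns both heads' outputs.

```rust
/// Hypothetical stand-in for the page's `Model` type, using slices instead of tensors.
type Model = Box<dyn Fn(&[f32]) -> (f32, Vec<f32>)>;

/// Builds a toy two-headed model: a shared feature (the input sum) feeds
/// a scalar "critic" head and an `nact`-sized "actor" head.
fn toy_model(nact: usize) -> Model {
    // The parameters are moved into the closure and live as long as the box.
    let critic_w = 0.5_f32;
    let actor_w: Vec<f32> = (0..nact).map(|i| i as f32).collect();
    Box::new(move |xs: &[f32]| {
        let shared: f32 = xs.iter().sum();
        let value = critic_w * shared;
        let logits: Vec<f32> = actor_w.iter().map(|w| w * shared).collect();
        (value, logits)
    })
}

fn main() {
    let model = toy_model(3);
    let (value, logits) = model(&[1.0, 2.0]);
    println!("value = {}, logits = {:?}", value, logits);
}
```

Returning a closure keeps the call site short (`model(&input)`) and avoids defining a struct just to hold the layers.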

FrameStack: Maintains a rolling buffer of the last 4 frames per process. When an episode ends (indicated by the done mask), that process's frame stack is zeroed before the new observation is added.
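The rolling-buffer behaviour can be sketched with plain vectors. This is a simplified illustration, not the module's tensor-based code: frames are flat `Vec<f32>` values of an arbitrary length, and the `done` flags stand in for the done mask.

```rust
/// Simplified frame stack: data[process][slot] holds one frame, oldest slot first.
struct FrameStack {
    data: Vec<Vec<Vec<f32>>>,
}

impl FrameStack {
    fn new(nprocs: usize, nstack: usize, frame_len: usize) -> Self {
        FrameStack { data: vec![vec![vec![0.0; frame_len]; nstack]; nprocs] }
    }

    /// `frames[p]` is the new observation for process p; `done[p]` says that
    /// the previous episode of process p has just ended.
    fn update(&mut self, frames: &[Vec<f32>], done: Option<&[bool]>) -> &[Vec<Vec<f32>>] {
        for (p, stack) in self.data.iter_mut().enumerate() {
            if done.map_or(false, |d| d[p]) {
                // Clear the history so frames from the old episode do not leak in.
                for slot in stack.iter_mut() {
                    slot.iter_mut().for_each(|x| *x = 0.0);
                }
            }
            // Shift every frame one slot towards the front, then overwrite the last slot.
            stack.rotate_left(1);
            let last = stack.len() - 1;
            stack[last].copy_from_slice(&frames[p]);
        }
        &self.data
    }
}

fn main() {
    let mut fs = FrameStack::new(2, 4, 1);
    fs.update(&[vec![1.0_f32], vec![1.0_f32]], None);
    fs.update(&[vec![2.0_f32], vec![2.0_f32]], None);
    // Process 1 finished its episode, so its history is cleared first.
    let data = fs.update(&[vec![3.0_f32], vec![3.0_f32]], Some(&[false, true][..]));
    println!("{:?}", data[0]); // [[0.0], [1.0], [2.0], [3.0]]
    println!("{:?}", data[1]); // [[0.0], [0.0], [0.0], [3.0]]
}
```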

Training loop: Runs for 1,000,000 updates. Each update collects 5 steps from each of the 16 parallel environments (80 transitions per update):

  1. Collect actions via multinomial sampling from the softmax policy.
  2. Compute N-step returns with discount factor 0.99 using the critic's bootstrap value.
  3. Calculate value loss (mean of squared advantages, where the advantage is the N-step return minus the critic's value estimate), action loss (policy gradient weighted by detached advantages), and entropy bonus (weighted by 0.01).
  4. Combined loss: 0.5 * value_loss + action_loss - 0.01 * entropy.
  5. Update with Adam optimizer (lr=1e-4) and gradient clipping at 0.5.
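Steps 2 to 4 can be sketched for a single environment with plain floats. This is an illustration under stated assumptions rather than the module's tensor code: `n_step_returns` and `a2c_loss` are hypothetical helper names, `masks[t]` is taken to be 0.0 when the episode ended at step t and 1.0 otherwise, and since there is no autograd here, "detached" only means that the advantage is treated as a constant weight in the action loss.

```rust
/// N-step discounted returns for one environment, bootstrapped from the
/// critic's value estimate for the state after the last step.
fn n_step_returns(rewards: &[f32], masks: &[f32], bootstrap: f32, gamma: f32) -> Vec<f32> {
    let mut returns = vec![0.0; rewards.len()];
    let mut next = bootstrap;
    for t in (0..rewards.len()).rev() {
        // A zero mask cuts the return at an episode boundary.
        next = rewards[t] + gamma * next * masks[t];
        returns[t] = next;
    }
    returns
}

/// Combined loss: 0.5 * value_loss + action_loss - 0.01 * entropy.
fn a2c_loss(returns: &[f32], values: &[f32], log_probs: &[f32], entropy: f32) -> f32 {
    let n = returns.len() as f32;
    let mut value_loss = 0.0_f32;
    let mut action_loss = 0.0_f32;
    for ((ret, val), lp) in returns.iter().zip(values).zip(log_probs) {
        let adv = ret - val; // advantage = return - critic value
        value_loss += adv * adv / n; // mean of squared advantages
        action_loss += -adv * lp / n; // policy gradient term
    }
    0.5 * value_loss + action_loss - 0.01 * entropy
}

fn main() {
    // Five steps from one environment, with an episode ending at step 2.
    let rewards = [0.0_f32, 1.0, 0.0, 0.0, 1.0];
    let masks = [1.0_f32, 1.0, 0.0, 1.0, 1.0];
    let returns = n_step_returns(&rewards, &masks, 0.5, 0.99);
    let values = [0.4_f32; 5];
    let log_probs = [-1.2_f32; 5];
    let loss = a2c_loss(&returns, &values, &log_probs, 1.6);
    println!("returns = {:?}, loss = {}", returns, loss);
}
```

Iterating backwards lets each return reuse the one after it, so the whole rollout is processed in a single pass.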

Model checkpoints are saved every 10,000 updates. A separate sample function loads trained weights and renders trajectories.

Usage

Use this example to learn how to implement actor-critic reinforcement learning with tch-rs, including parallel environment interaction via vectorized Gym environments, frame stacking, and N-step return estimation.

Code Reference

Source Location

Signature

type Model = Box<dyn Fn(&Tensor) -> (Tensor, Tensor)>;

fn model(p: &nn::Path, nact: i64) -> Model

struct FrameStack {
    data: Tensor,
    nprocs: i64,
    nstack: i64,
}

impl FrameStack {
    fn new(nprocs: i64, nstack: i64) -> FrameStack
    fn update<'a>(&'a mut self, img: &Tensor, masks: Option<&Tensor>) -> &'a Tensor
}

pub fn train() -> cpython::PyResult<()>

pub fn sample<T: AsRef<std::path::Path>>(weight_file: T) -> cpython::PyResult<()>

Import

// Module within the reinforcement-learning example.
use super::vec_gym_env::VecGymEnv;
use tch::kind::{FLOAT_CPU, INT64_CPU};
use tch::{nn, nn::OptimizerConfig, Kind::Float, Tensor};

I/O Contract

Inputs

Name | Type | Required | Description
ENV_NAME | Constant (String) | Yes | OpenAI Gym environment name ("SpaceInvadersNoFrameskip-v4").
NPROCS | Constant (i64) | Yes | Number of parallel environment processes (16).
NSTEPS | Constant (i64) | Yes | Number of steps per update (5).
NSTACK | Constant (i64) | Yes | Number of frames to stack (4).
weight_file | Path (for sample) | For sampling | Path to saved model weights (.ot file).

Outputs

Name | Type | Description
stdout | Text | Training progress: update index, total episodes, and average reward per episode (every 500 updates).
a2cN.ot | File (weights) | Model checkpoints saved every 10,000 updates.
/dev/shm frames | Image files | Observation frames saved during sampling.

Usage Examples

use super::vec_gym_env::VecGymEnv;
use tch::kind::{FLOAT_CPU, INT64_CPU};
use tch::{nn, nn::OptimizerConfig, Kind::Float, Tensor};

const ENV_NAME: &str = "SpaceInvadersNoFrameskip-v4";
const NPROCS: i64 = 16;
const NSTEPS: i64 = 5;
const NSTACK: i64 = 4;

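// Create the vectorized environment. The constructor arguments are an assumption
// based on the example's vec_gym_env module.
let env = VecGymEnv::new(ENV_NAME, None, NPROCS)?;

// Build the actor-critic model
let vs = nn::VarStore::new(tch::Device::cuda_if_available());
let model = model(&vs.root(), env.action_space());
let mut opt = nn::Adam::default().build(&vs, 1e-4).unwrap();

// Collect rollout data
let mut frame_stack = FrameStack::new(NPROCS, NSTACK);
let _ = frame_stack.update(&env.reset()?, None);

// s_states is the rollout buffer of stacked frames, indexed by step (allocation elided).
for s in 0..NSTEPS {
    // Sample one action per process without tracking gradients
    let (critic, actor) = tch::no_grad(|| model(&s_states.get(s)));
    let probs = actor.softmax(-1, Float);
    let actions = probs.multinomial(1, true).squeeze_dim(-1);
    let step = env.step(Vec::<i64>::try_from(&actions).unwrap())?;
    // ... store rewards, masks, observations ...
}

// Compute loss and update. value_loss, action_loss, and dist_entropy are computed
// from the stored rollout (elided).
let loss = value_loss * 0.5 + action_loss - dist_entropy * 0.01;
opt.backward_step_clip(&loss, 0.5);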
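The sampling step in the loop above turns action logits into probabilities and draws one action per process. A tensor-free sketch of that idea follows, together with the entropy that feeds the entropy bonus. The helper names are illustrative, and the uniform sample `u` is passed in by the caller as an assumption to keep the example deterministic. A real program would draw it from a random number generator, and tch's `multinomial` performs the draw internally.

```rust
/// Numerically stable softmax over action logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Draws an action index by inverse-CDF sampling; `u` is a uniform sample in [0, 1).
fn sample_action(probs: &[f32], u: f32) -> usize {
    let mut cumulative = 0.0_f32;
    for (i, p) in probs.iter().enumerate() {
        cumulative += *p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1 // guard against rounding when u is close to 1
}

/// Entropy of the policy distribution (zero-probability actions contribute nothing).
fn entropy(probs: &[f32]) -> f32 {
    -probs.iter().filter(|p| **p > 0.0).map(|p| p * p.ln()).sum::<f32>()
}

fn main() {
    let probs = softmax(&[1.0, 2.0, 0.5]);
    let action = sample_action(&probs, 0.3);
    println!("probs = {:?}, action = {}, entropy = {}", probs, action, entropy(&probs));
}
```

A higher entropy means a more uniform policy, which is why subtracting `0.01 * entropy` from the loss encourages exploration.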
