Implementation:LaurentMazare Tch rs DDPG Agent
Overview
ddpg.rs is a Rust implementation of the Deep Deterministic Policy Gradient (DDPG) algorithm located at examples/reinforcement-learning/ddpg.rs (343 lines) in the tch-rs repository. DDPG is an actor-critic reinforcement learning algorithm designed for continuous action spaces, as described in "Continuous control with deep reinforcement learning" (Lillicrap et al., 2015). This implementation demonstrates how to build a complete deep RL training loop using the tch-rs bindings to PyTorch, including neural network construction, experience replay, Ornstein-Uhlenbeck exploration noise, target network soft updates, and episodic training on the Pendulum-v1 environment via a Python Gym bridge.
Code Reference
Constants
| Constant | Value | Purpose |
|---|---|---|
| GAMMA | 0.99 | Discount factor for future rewards in the Bellman equation |
| TAU | 0.005 | Soft update weight for target network tracking |
| REPLAY_BUFFER_CAPACITY | 100,000 | Maximum number of transitions stored in the replay buffer |
| TRAINING_BATCH_SIZE | 100 | Number of transitions sampled per training iteration |
| MAX_EPISODES | 100 | Total number of training episodes |
| EPISODE_LENGTH | 200 | Maximum number of steps per episode |
| TRAINING_ITERATIONS | 200 | Number of gradient updates after each episode |
| MU | 0.0 | Mean of the Ornstein-Uhlenbeck process |
| THETA | 0.15 | Mean-reversion rate of the OU process |
| SIGMA | 0.1 | Volatility (noise magnitude) of the OU process |
| ACTOR_LEARNING_RATE | 1e-4 | Learning rate for the actor network (Adam optimizer) |
| CRITIC_LEARNING_RATE | 1e-3 | Learning rate for the critic network (Adam optimizer) |
OuNoise
struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}
Implements the Ornstein-Uhlenbeck process for temporally correlated exploration noise in continuous action spaces.
Key methods:
- fn new(mu: f64, theta: f64, sigma: f64, num_actions: usize) -> Self -- Initializes the noise process with state set to ones of size num_actions.
- fn sample(&mut self) -> &Tensor -- Computes one OU step: dx = theta * (mu - state) + sigma * N(0,1), updates the internal state, and returns a reference to the noise tensor.
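The OU step can be sketched on plain f64 slices, without tch. This is an illustration under stated assumptions, not the code from ddpg.rs: the name ou_step is hypothetical, and the Gaussian draws are passed in as an argument (rather than sampled internally) so the example stays deterministic and dependency-free.

```rust
/// One Ornstein-Uhlenbeck step on a plain slice.
/// Hypothetical helper for illustration; ddpg.rs operates on a Tensor instead.
/// `gaussian` holds one pre-drawn N(0,1) sample per action dimension.
fn ou_step(state: &mut [f64], mu: f64, theta: f64, sigma: f64, gaussian: &[f64]) {
    assert_eq!(state.len(), gaussian.len());
    for (x, n) in state.iter_mut().zip(gaussian) {
        // dx = theta * (mu - state) + sigma * N(0,1)
        let dx = theta * (mu - *x) + sigma * n;
        // The new state is the old state plus the increment.
        *x += dx;
    }
}
```

With MU = 0.0, THETA = 0.15, and a state that starts at 1.0, a step with a zero Gaussian draw moves the state to 0.85: the mean-reversion term pulls the noise back toward MU, which is what makes successive samples temporally correlated.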
ReplayBuffer
struct ReplayBuffer {
    obs: Tensor,
    next_obs: Tensor,
    rewards: Tensor,
    actions: Tensor,
    capacity: usize,
    len: usize,
    i: usize,
}
A fixed-capacity circular replay buffer that stores transitions as pre-allocated tensors on CPU.
Key methods:
- fn new(capacity: usize, num_obs: usize, num_actions: usize) -> Self -- Allocates zero-filled tensors of shape [capacity, num_obs] for obs and next_obs, [capacity, num_actions] for actions, and [capacity, 1] for rewards.
- fn push(&mut self, obs: &Tensor, actions: &Tensor, reward: &Tensor, next_obs: &Tensor) -- Stores a transition at the current circular index. Uses copy_ for in-place tensor updates.
- fn random_batch(&self, batch_size: usize) -> Option<(Tensor, Tensor, Tensor, Tensor)> -- Returns a random batch of (states, actions, rewards, next_states), using Tensor::randint for index sampling. Returns None if fewer than 3 transitions are stored.
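The circular-index bookkeeping can be sketched with standard-library types only. The Transition and RingBuffer types below are hypothetical stand-ins for the tensor-backed buffer, and batch takes caller-supplied indices instead of drawing them at random, so this is a simplified illustration rather than the example's implementation.

```rust
/// A single stored transition (hypothetical plain-Rust stand-in for the tensor rows).
#[derive(Clone, Debug, PartialEq)]
struct Transition {
    obs: Vec<f64>,
    action: Vec<f64>,
    reward: f64,
    next_obs: Vec<f64>,
}

/// Fixed-capacity circular buffer: once full, new entries overwrite the oldest.
struct RingBuffer {
    slots: Vec<Option<Transition>>,
    len: usize, // number of valid entries, never exceeds the capacity
    i: usize,   // next write position
}

impl RingBuffer {
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0);
        RingBuffer { slots: vec![None; capacity], len: 0, i: 0 }
    }

    fn push(&mut self, t: Transition) {
        let capacity = self.slots.len();
        self.slots[self.i] = Some(t);
        // Wrap the write position back to 0 after the last slot.
        self.i = (self.i + 1) % capacity;
        if self.len < capacity {
            self.len += 1;
        }
    }

    /// Looks up transitions at the given indices.
    /// Returns None while fewer than 3 transitions are stored,
    /// or if any index points at an empty or out-of-range slot.
    fn batch(&self, indices: &[usize]) -> Option<Vec<&Transition>> {
        if self.len < 3 {
            return None;
        }
        indices.iter().map(|&k| self.slots.get(k)?.as_ref()).collect()
    }
}
```

Pushing four transitions into a buffer of capacity 3 leaves len at 3, moves the write position to 1, and overwrites slot 0 with the newest transition.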
Actor
struct Actor {
    var_store: nn::VarStore,
    network: nn::Sequential,
    device: Device,
    num_obs: usize,
    num_actions: usize,
    opt: nn::Optimizer,
    learning_rate: f64,
}
The policy network that maps observations to continuous actions.
Architecture:
- Linear layer: num_obs -> 400, followed by ReLU
- Linear layer: 400 -> 300, followed by ReLU
- Linear layer: 300 -> num_actions, followed by tanh (output bounded to [-1, 1])
Key methods:
- fn new(num_obs: usize, num_actions: usize, learning_rate: f64) -> Self -- Builds the network with an Adam optimizer.
- fn forward(&self, obs: &Tensor) -> Tensor -- Runs a forward pass, moving observations to the correct device.
- impl Clone for Actor -- Deep-clones the actor by creating a new instance and copying the VarStore.
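To make the layer stack concrete, here is a dependency-free sketch of the same shape of forward pass: ReLU after each hidden layer and tanh on the output. The helpers linear, relu, and actor_forward and the Layer alias are hypothetical names; the real example builds the network from tch's nn modules and runs on tensors.

```rust
/// Dense layer: y = W x + b, with W stored as one row per output unit.
fn linear(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>() + bias)
        .collect()
}

/// Element-wise ReLU.
fn relu(v: Vec<f64>) -> Vec<f64> {
    v.into_iter().map(|x| x.max(0.0)).collect()
}

/// One (weights, biases) pair per layer (hypothetical representation).
type Layer = (Vec<Vec<f64>>, Vec<f64>);

/// Actor-style forward pass: ReLU after every hidden layer, tanh on the output layer.
fn actor_forward(layers: &[Layer], obs: &[f64]) -> Vec<f64> {
    let (last, hidden) = layers.split_last().expect("at least one layer");
    let mut h = obs.to_vec();
    for (w, b) in hidden {
        h = relu(linear(w, b, &h));
    }
    // tanh keeps every action component inside [-1, 1].
    linear(&last.0, &last.1, &h).into_iter().map(f64::tanh).collect()
}
```

Because the last step is tanh, the output stays within [-1, 1] no matter how large the pre-activation values become.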
Critic
struct Critic {
    var_store: nn::VarStore,
    network: nn::Sequential,
    device: Device,
    num_obs: usize,
    num_actions: usize,
    opt: nn::Optimizer,
    learning_rate: f64,
}
The Q-value network that estimates the value of a state-action pair.
Architecture:
- Linear layer: (num_obs + num_actions) -> 400, followed by ReLU
- Linear layer: 400 -> 300, followed by ReLU
- Linear layer: 300 -> 1 (scalar Q-value, no activation)
Key methods:
- fn new(num_obs: usize, num_actions: usize, learning_rate: f64) -> Self -- Builds the network. The input dimension is the sum of the observation and action dimensions, because the two are concatenated.
- fn forward(&self, obs: &Tensor, actions: &Tensor) -> Tensor -- Concatenates actions and observations along dimension 1, then applies the network.
- impl Clone for Critic -- Deep-clones via VarStore::copy.
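For a single sample, the concatenation amounts to placing the action values in front of the observation values. The helper below is a hypothetical plain-slice illustration of that ordering; the example itself concatenates batched tensors along dimension 1.

```rust
/// Builds one critic input row: actions first, then observations.
/// Hypothetical helper; ddpg.rs does this with a tensor concatenation.
fn critic_input(actions: &[f64], obs: &[f64]) -> Vec<f64> {
    let mut row = Vec::with_capacity(actions.len() + obs.len());
    row.extend_from_slice(actions);
    row.extend_from_slice(obs);
    row
}
```

For Pendulum-v1 (1 action, 3 observation values) the resulting row has 4 entries, matching the num_obs + num_actions input width of the first linear layer.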
Helper Function: track
fn track(dest: &mut nn::VarStore, src: &nn::VarStore, tau: f64)
Performs Polyak (soft) averaging to update target network parameters: dest = tau * src + (1 - tau) * dest. Operates inside tch::no_grad to avoid tracking gradients.
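The averaging rule can be sketched on flat parameter slices. The name soft_update is hypothetical, and this version works on f64 values rather than on the variables of a VarStore.

```rust
/// Polyak averaging on flat parameter slices:
/// dest = tau * src + (1 - tau) * dest, applied element-wise.
fn soft_update(dest: &mut [f64], src: &[f64], tau: f64) {
    assert_eq!(dest.len(), src.len());
    for (d, s) in dest.iter_mut().zip(src) {
        *d = tau * s + (1.0 - tau) * *d;
    }
}
```

With TAU = 0.005 the target parameters move only 0.5% of the way toward the online parameters on each call, so the targets used in the Bellman update change slowly. Setting tau to 1.0 degenerates to a hard copy.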
Agent
struct Agent {
    actor: Actor,
    actor_target: Actor,
    critic: Critic,
    critic_target: Critic,
    replay_buffer: ReplayBuffer,
    ou_noise: OuNoise,
    train: bool,
    gamma: f64,
    tau: f64,
}
The top-level DDPG agent that orchestrates all components.
Key methods:
- fn new(actor, critic, ou_noise, replay_buffer_capacity, train, gamma, tau) -> Self -- Clones actor and critic to create target networks; allocates the replay buffer.
- fn actions(&mut self, obs: &Tensor) -> Tensor -- Computes actions using the actor (inside no_grad) and adds OU noise if in training mode.
- fn remember(&mut self, obs, actions, reward, next_obs) -- Stores a transition in the replay buffer.
- fn train(&mut self, batch_size: usize) -- Performs one DDPG training step:
  - Samples a random batch from the replay buffer (returns early if insufficient data).
  - Computes target Q-values: q_target = reward + gamma * critic_target(next_state, actor_target(next_state)), detached from the computation graph.
  - Updates the critic by minimizing MSE between predicted and target Q-values.
  - Updates the actor by maximizing the critic's Q-value estimate for the actor's chosen actions (negated for gradient descent).
  - Soft-updates both target networks via track.
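The quantities computed in one training step can be sketched numerically on plain slices. The function names below are hypothetical, the network outputs are passed in as precomputed values, and the mean reductions are an assumption; gradient computation and the optimizer steps are left out entirely.

```rust
/// Bellman targets: y_i = r_i + gamma * q'_i, where `next_q` stands in for
/// critic_target(next_state, actor_target(next_state)) evaluated per sample.
fn td_targets(rewards: &[f64], next_q: &[f64], gamma: f64) -> Vec<f64> {
    assert_eq!(rewards.len(), next_q.len());
    rewards.iter().zip(next_q).map(|(r, q)| r + gamma * q).collect()
}

/// Critic loss: mean squared error between predicted and target Q-values.
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    let sum: f64 = pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum();
    sum / pred.len() as f64
}

/// Actor loss: negated mean Q-value of the actor's own actions,
/// so that minimizing the loss maximizes Q.
fn actor_loss(q_of_policy_actions: &[f64]) -> f64 {
    -q_of_policy_actions.iter().sum::<f64>() / q_of_policy_actions.len() as f64
}
```

In the tensor version the targets are detached before the critic loss is computed, so the critic update does not backpropagate into the target networks; on plain numbers there is no graph, so that step has no counterpart here.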
Entry Point
pub fn run() -> cpython::PyResult<()>
Creates a Pendulum-v1 environment via the GymEnv Python bridge, constructs the agent with the default hyperparameters, and runs the training loop:
- For each of MAX_EPISODES episodes, collects up to EPISODE_LENGTH steps.
- Actions are scaled by 2.0 and clamped to [-2.0, 2.0] (matching Pendulum-v1's action range).
- After each episode, performs TRAINING_ITERATIONS gradient updates.
- Prints the total reward per episode.
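The scaling and clamping step can be sketched as a one-line helper on a scalar. The name to_env_action is hypothetical; the example applies the same arithmetic to the value returned by the agent.

```rust
/// Maps the (possibly noisy) actor output to Pendulum-v1's torque range:
/// scale by 2.0, then clamp to [-2.0, 2.0].
fn to_env_action(actor_output: f64) -> f64 {
    (2.0 * actor_output).clamp(-2.0, 2.0)
}
```

The clamp is not redundant with tanh: the actor output itself lies in [-1, 1], but the OU noise added in training mode can push the sum outside that interval, for example an output of 1.3 would scale to 2.6 and is clamped to 2.0.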
I/O Contract
Inputs
- Environment: Pendulum-v1 from OpenAI Gym, accessed through GymEnv (a cpython bridge defined in the same example directory).
- Observation space: Continuous vector (3-dimensional for Pendulum-v1: cos(theta), sin(theta), angular velocity).
- Action space: Continuous scalar (1-dimensional for Pendulum-v1, range [-2.0, 2.0]).
Outputs
- Actions: Continuous-valued tensors produced by the actor network, bounded by tanh to [-1, 1], then scaled and clamped to [-2.0, 2.0].
- Console output: Per-episode total reward, action space dimensions, and observation space shape.
- Trained model: The actor and critic networks are trained in-place (no model saving is performed in this example).
Invariants
- The replay buffer requires at least 3 stored transitions before training can begin.
- Target networks are always initialized as exact copies of the online networks.
- OU noise is only applied during training mode (self.train == true).
- All tensor operations run on CPU (tch::Device::Cpu).
Algorithm Summary
The DDPG update follows this sequence each training step:
- Sample a mini-batch of (s, a, r, s') from the replay buffer.
- Critic update: Compute y = r + gamma * Q_target(s', mu_target(s')). Minimize MSE(Q(s, a), y).
- Actor update: Maximize Q(s, mu(s)) with respect to the actor parameters (equivalently, minimize -Q(s, mu(s))).
- Soft update: Move target networks toward online networks using tau.
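The same sequence can be written as equations. The notation is an assumption introduced here for illustration: $\theta$ and $\phi$ denote the actor and critic parameters, primes mark the target networks, and $N$ is the mini-batch size; the averaging over the batch assumes mean reductions for both losses.

```latex
\begin{align*}
y_i &= r_i + \gamma \, Q_{\phi'}\bigl(s'_i, \mu_{\theta'}(s'_i)\bigr)
  && \text{(target, treated as a constant)} \\
L(\phi) &= \frac{1}{N} \sum_{i=1}^{N} \bigl(Q_{\phi}(s_i, a_i) - y_i\bigr)^2
  && \text{(critic loss)} \\
J(\theta) &= -\frac{1}{N} \sum_{i=1}^{N} Q_{\phi}\bigl(s_i, \mu_{\theta}(s_i)\bigr)
  && \text{(actor loss)} \\
\theta' &\leftarrow \tau \, \theta + (1 - \tau) \, \theta', \qquad
\phi' \leftarrow \tau \, \phi + (1 - \tau) \, \phi'
  && \text{(soft update)}
\end{align*}
```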
Dependencies
- tch -- Rust bindings to libtorch for tensor operations, neural network modules, and optimizers
- cpython -- Python interoperability for accessing OpenAI Gym environments
- super::gym_env::GymEnv -- Rust wrapper around Python Gym environments defined in the same example module
Related Pages
- Principle:LaurentMazare_Tch_rs_Deep_Deterministic_Policy_Gradient -- The guiding principle behind the DDPG algorithm and its application to continuous control