Heuristic: Danijar's DreamerV3 Adaptive Gradient Clipping
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning, Reinforcement_Learning |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Gradient clipping technique that scales update norms relative to parameter norms, providing scale-invariant gradient control without requiring manual threshold tuning.
Description
Adaptive Gradient Clipping (AGC) clips the gradient update for each parameter tensor based on the ratio of the update norm to the parameter norm. If the update norm exceeds `clip * max(pmin, param_norm)`, the update is scaled down proportionally. This replaces traditional global gradient clipping with a scale-aware mechanism that adapts to the natural magnitude of each parameter.
DreamerV3 uses AGC as the first step in its custom optimizer chain (before RMS scaling and momentum), with a default clip ratio of 0.3 and a minimum parameter norm floor of 1e-3.
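For a single tensor, the clipping rule can be sketched in plain NumPy (an illustrative re-derivation of the formula above, not the DreamerV3 code itself):

```python
import numpy as np

def agc_scale(update, param, clip=0.3, pmin=1e-3):
    """Return the AGC-scaled update for one parameter tensor."""
    unorm = np.linalg.norm(update.ravel())  # L2 norm of the update
    pnorm = np.linalg.norm(param.ravel())   # L2 norm of the parameter
    upper = clip * max(pmin, pnorm)         # largest allowed update norm
    # Scale down only when the update norm exceeds the bound.
    return update * (1.0 / max(1.0, unorm / upper))

param = np.array([3.0, 4.0])         # ||param|| = 5, so bound = 0.3 * 5 = 1.5
big_update = np.array([3.0, 0.0])    # ||update|| = 3 > 1.5, gets rescaled to norm 1.5
small_update = np.array([0.3, 0.0])  # ||update|| = 0.3 <= 1.5, passes through unchanged

print(np.linalg.norm(agc_scale(big_update, param)))  # 1.5
print(agc_scale(small_update, param))                # [0.3 0. ]
```

Note that clipping acts per tensor, so a single exploding layer is tamed without shrinking updates elsewhere.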
Usage
Applied automatically as part of the DreamerV3 optimizer chain. AGC is the first gradient transform applied before `scale_by_rms` and `scale_by_momentum`. Use when training deep networks where different layers may have vastly different parameter and gradient scales.
The Insight (Rule of Thumb)
- Action: Use AGC with `clip=0.3` and `pmin=1e-3` as the first gradient transform in the optimizer chain.
- Value: `agc=0.3` is the default. The `pmin=1e-3` floor keeps the bound from collapsing to zero (and avoids a divide-by-zero in the norm ratio) for freshly initialized or near-zero parameters.
- Trade-off: AGC is slightly more expensive than global clipping (per-parameter norm computation), but eliminates the need to tune a global gradient clipping threshold which varies across model sizes and tasks.
- Compatibility: Works with any optimizer chain. In DreamerV3, it is combined with a custom RMS-based Adam variant rather than stock `optax.adam`. (Optax also ships `optax.adaptive_grad_clip`, which implements the unit-wise AGC of Brock et al.'s NFNets; DreamerV3's variant instead uses whole-tensor norms.)
Reasoning
Global gradient clipping (e.g., `max_norm=100`) requires tuning the threshold for each model size and task. AGC normalizes updates relative to each parameter's own scale, making it invariant to model architecture changes. This is essential for DreamerV3's goal of fixed hyperparameters across diverse domains.
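The invariance is easy to check numerically: if two layers differ in scale by 100x but receive updates that are the same fraction of their own parameter norm, AGC applies the same scale factor to both, whereas a fixed global threshold would treat them very differently. A small NumPy check (illustrative, not from the repo):

```python
import numpy as np

def agc_factor(update, param, clip=0.3, pmin=1e-3):
    """Scale factor AGC applies to one tensor's update."""
    unorm = np.linalg.norm(update)
    upper = clip * max(pmin, np.linalg.norm(param))
    return 1.0 / max(1.0, unorm / upper)

small_layer = np.array([0.01, 0.02])  # tiny parameters
large_layer = small_layer * 100.0     # same direction, 100x larger scale

# Updates that are 50% of each layer's own norm:
u_small = 0.5 * small_layer
u_large = 0.5 * large_layer

# Both updates exceed the 30% bound by the same ratio, so both get
# the identical scale factor: 0.3 / 0.5 = 0.6.
print(agc_factor(u_small, small_layer))  # 0.6
print(agc_factor(u_large, large_layer))  # 0.6
```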
The implementation from `embodied/jax/opt.py:L109-123`:
```python
def clip_by_agc(clip=0.3, pmin=1e-3):
  def init_fn(params):
    return ()
  def update_fn(updates, state, params=None):
    def fn(param, update):
      unorm = jnp.linalg.norm(update.flatten(), 2)
      pnorm = jnp.linalg.norm(param.flatten(), 2)
      upper = clip * jnp.maximum(pmin, pnorm)
      return update * (1 / jnp.maximum(1.0, unorm / upper))
    updates = jax.tree.map(fn, params, updates) if clip else updates
    return updates, ()
  return optax.GradientTransformation(init_fn, update_fn)
```
The optimizer chain assembly from `dreamerv3/agent.py:L357-379`:
```python
chain = []
chain.append(embodied.jax.opt.clip_by_agc(agc))                    # 1. AGC
chain.append(embodied.jax.opt.scale_by_rms(beta2, eps))            # 2. RMS scaling
chain.append(embodied.jax.opt.scale_by_momentum(beta1, nesterov))  # 3. Momentum
# ... weight decay, learning rate schedule
return optax.chain(*chain)
```
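To see why the ordering matters, here is a minimal NumPy walk-through of one step of a chain in this spirit, for a scalar parameter. The `rms` and momentum steps below are textbook RMSProp/EMA approximations of `scale_by_rms` and `scale_by_momentum`, not the DreamerV3 implementations (which differ in details such as bias correction); the point is that AGC bounds the raw gradient before the RMS statistics ever see it:

```python
import numpy as np

def step(param, grad, nu, mu, clip=0.3, pmin=1e-3,
         beta1=0.9, beta2=0.999, eps=1e-20, lr=4e-5):
    # 1. AGC: bound the gradient relative to the parameter's own scale.
    upper = clip * max(pmin, abs(param))
    g = grad * (1.0 / max(1.0, abs(grad) / upper))
    # 2. RMS scaling: divide by a running root-mean-square of gradients.
    nu = beta2 * nu + (1 - beta2) * g ** 2
    g = g / (np.sqrt(nu) + eps)
    # 3. Momentum: exponential moving average of the scaled updates.
    mu = beta1 * mu + (1 - beta1) * g
    # 4. Learning rate applied last.
    return param - lr * mu, nu, mu

# A pathological gradient of 100 on a parameter of 1.0 is first
# clipped to 0.3 by AGC, so it cannot corrupt the RMS statistics.
param, nu, mu = step(1.0, grad=100.0, nu=0.0, mu=0.0)
```

If AGC came after RMS scaling instead, the outlier gradient would already have inflated `nu`, shrinking every subsequent update.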
Default config from `dreamerv3/configs.yaml:L87`:
```yaml
opt: {lr: 4e-5, agc: 0.3, eps: 1e-20, beta1: 0.9, beta2: 0.999}
```