Heuristic: ARISE Initiative Robomimic BatchNorm To GroupNorm For EMA
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-15 07:30 GMT |
Overview
When using Exponential Moving Average (EMA) with Diffusion Policy, all BatchNorm layers must be replaced with GroupNorm to prevent severe performance degradation.
Description
EMA maintains a shadow copy of model weights using an exponential moving average. BatchNorm layers maintain running statistics (mean and variance) that are updated during the forward pass, not through gradient-based weight updates. When EMA averages the weights, it incorrectly blends these running statistics, causing the averaged model to produce poor predictions. GroupNorm computes normalization per-group within each sample, requiring no running statistics, and thus works correctly with EMA weight averaging.
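The difference is visible directly in PyTorch: BatchNorm carries its running statistics as buffers (state updated outside of gradient descent), while GroupNorm has no buffers at all. A minimal check:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(8)
gn = nn.GroupNorm(num_groups=4, num_channels=8)

# BatchNorm tracks running_mean / running_var / num_batches_tracked as
# buffers; these mutate on every training-mode forward pass.
print([name for name, _ in bn.named_buffers()])

# GroupNorm has no buffers: statistics come from the input each time,
# so there is no hidden state for EMA to mishandle.
print([name for name, _ in gn.named_buffers()])
```

Because GroupNorm is stateless in this sense, a weight-averaged copy of the model behaves consistently with its averaged parameters.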
Usage
Apply this heuristic whenever implementing Diffusion Policy or any model that combines observation encoders with EMA-based weight averaging. This is particularly critical because the performance degradation is silent — training metrics may look reasonable but evaluation performance will be poor. The code explicitly warns: "performance will tank if you forget to do this!"
The Insight (Rule of Thumb)
- Action: Call `replace_bn_with_gn(obs_encoder)` on the observation encoder before using it with EMA.
- Value: All `nn.BatchNorm` layers are recursively replaced with `nn.GroupNorm`.
- Trade-off: GroupNorm is slightly slower than BatchNorm for large batch sizes, but this difference is negligible compared to the correctness gained.
- Compatibility: Also requires setting `use_cache=False` during training when using transformer backbones, since cached key/value states interfere with gradient computation.
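The replacement itself is a straightforward recursive module swap. The following is a minimal sketch of what a `replace_bn_with_gn` helper might look like; the `features_per_group` default is an assumption here, and the actual robomimic helper may differ in details:

```python
import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, features_per_group: int = 16) -> nn.Module:
    """Recursively replace all BatchNorm layers with GroupNorm (sketch)."""
    for name, child in module.named_children():
        if isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            num_channels = child.num_features
            # One group per `features_per_group` channels (hypothetical default).
            setattr(module, name, nn.GroupNorm(
                num_groups=max(1, num_channels // features_per_group),
                num_channels=num_channels))
        else:
            # Recurse into containers (Sequential, custom encoders, etc.).
            replace_bn_with_gn(child, features_per_group)
    return module

# Usage: swap the layers in-place before wrapping the model with EMA.
encoder = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32), nn.ReLU())
encoder = replace_bn_with_gn(encoder)
```

Note that the swap must happen before the EMA shadow copy is created, so that both the live and the averaged model share the GroupNorm structure.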
Reasoning
From `robomimic/algo/diffusion_policy.py:66-69`:

```python
# IMPORTANT!
# replace all BatchNorm with GroupNorm to work with EMA
# performance will tank if you forget to do this!
obs_encoder = replace_bn_with_gn(obs_encoder)
```
BatchNorm tracks `running_mean` and `running_var` as buffers (not parameters). EMA averages parameters, but buffers are handled inconsistently (copied, skipped, or averaged, depending on the implementation), leaving normalization statistics that don't correspond to any real data distribution. GroupNorm avoids this entirely by computing statistics from the input at each forward pass.
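A minimal EMA update (an illustrative sketch, not the `diffusers` `EMAModel` implementation) makes the mismatch concrete: parameters are blended exponentially, while buffers are copied verbatim, so a BatchNorm layer's running statistics in the shadow model never match its blended weights:

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.999):
    # Parameters: exponential blend of shadow and live weights.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
    # Buffers (e.g. BatchNorm running stats): copied verbatim here. Either
    # copying or averaging leaves them inconsistent with the blended
    # parameters, which is why running-stat layers break under EMA.
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)

model = nn.Linear(4, 4)
ema_model = copy.deepcopy(model)  # shadow copy tracked across training
```

With GroupNorm there are no buffers to reconcile, so the parameter blend alone fully defines the averaged model.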
The Diffusion Policy paper (Chi et al.) and implementation require this replacement for correct behavior with the `EMAModel` from the `diffusers` library (configured with decay `power=0.75` in the default config).