Principle:NVIDIA TransformerEngine Distributed Initialization
Overview
Managing reproducible random number generation across distributed GPU processes for tensor-parallel training.
Description
In tensor-parallel training, different GPUs need different random states for operations like dropout (to avoid correlated noise) while sharing the same state for weight initialization (to ensure consistent initial parameters). CudaRNGStatesTracker provides named, forkable CUDA RNG states for this purpose.
The core challenge is that tensor parallelism splits activations and parameters across GPUs. Without careful RNG management:
- Dropout masks would be identical across tensor-parallel ranks if they share the same RNG state, introducing correlated noise that degrades training quality.
- Parameters that must be identical on every rank would be initialized to different values if the ranks used different RNG states, breaking the mathematical equivalence between tensor-parallel and single-GPU execution.
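The first failure mode can be reproduced on the CPU. This stdlib-only sketch simulates two tensor-parallel ranks, with `random.Random` standing in for each GPU's CUDA generator. The names `dropout_mask`, `BASE_SEED`, and `TP_SIZE` are hypothetical. A shared seed yields the same mask on every rank's shard, while a rank-dependent seed does not:

```python
import random

def dropout_mask(rng: random.Random, n: int, p: float = 0.5) -> list[int]:
    # 1 keeps an element, 0 drops it (drop probability p).
    return [0 if rng.random() < p else 1 for _ in range(n)]

BASE_SEED = 1234  # hypothetical training seed
TP_SIZE = 2       # hypothetical tensor-parallel group size
N = 64            # elements in each rank's activation shard

# Shared seed: every tensor-parallel rank draws the same mask for its shard.
shared = [dropout_mask(random.Random(BASE_SEED), N) for _ in range(TP_SIZE)]

# Rank-dependent seed: each rank draws its own mask.
per_rank = [dropout_mask(random.Random(BASE_SEED + rank), N) for rank in range(TP_SIZE)]

print("shared seed, identical masks:", shared[0] == shared[1])
print("per-rank seed, identical masks:", per_rank[0] == per_rank[1])
```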
The solution is to maintain multiple named RNG states that can be selectively activated:
- A "model-parallel-rng" state, seeded differently on each tensor-parallel (TP) rank, for operations such as dropout that must be uncorrelated across ranks.
- The default CUDA RNG state, seeded identically on every rank of the TP group, for operations that must be consistent across ranks.
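A minimal sketch of how one rank might derive the two seeds. The helper `make_seeds` and the offset constant are hypothetical: the offset is arbitrary and only moves the model-parallel seeds away from the default seed. `random.Random` again stands in for a CUDA generator:

```python
import random

# Hypothetical constant separating the model-parallel seed from the default seed.
MODEL_PARALLEL_SEED_OFFSET = 2718

def make_seeds(base_seed: int, tp_rank: int) -> dict[str, int]:
    """Return the seeds one tensor-parallel rank would use (illustrative)."""
    return {
        # Identical on every TP rank: operations that must agree across ranks.
        "default": base_seed,
        # Different on every TP rank: dropout on sharded activations.
        "model-parallel-rng": base_seed + MODEL_PARALLEL_SEED_OFFSET + tp_rank,
    }

seeds = [make_seeds(base_seed=42, tp_rank=r) for r in range(4)]
default_draws = [random.Random(s["default"]).random() for s in seeds]
mp_draws = [random.Random(s["model-parallel-rng"]).random() for s in seeds]

print("default draws agree across ranks:", len(set(default_draws)) == 1)
print("model-parallel draws all distinct:", len(set(mp_draws)) == len(mp_draws))
```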
Theoretical Basis
Each process keeps its default CUDA RNG state plus a set of named CUDA RNG states, and a context manager switches between them. The mechanism works as follows:
- Registration: Each named state is registered with its own seed via add(name, seed). The seed typically incorporates the tensor-parallel rank so that each GPU draws a different random sequence.
- Forking: The fork(name) context manager saves the current CUDA RNG state, switches to the named state, and executes the enclosed operations. On exit it stores the advanced named state back in the tracker, so the next fork continues the sequence instead of replaying it, and then restores the original state.
- Reproducibility: Because the seeds are deterministic functions of the rank and the configuration, the random streams (and therefore the dropout masks) are reproducible across runs with the same setup.
The "model-parallel-rng" state differs across the ranks of a tensor-parallel group, so each rank applies an independent dropout mask to its shard, while the default CUDA RNG state stays identical across the group for operations that must agree.
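The registration and forking steps can be sketched in plain Python. This is a CPU analogue under stated assumptions, not TE's code: the module-level random generator stands in for the default CUDA generator, random.getstate/random.setstate stand in for torch.cuda.get_rng_state/torch.cuda.set_rng_state, and the class name RNGStatesTracker is hypothetical:

```python
import contextlib
import random

class RNGStatesTracker:
    """CPU stand-in for a CUDA RNG states tracker (illustrative only).

    The module-level `random` generator plays the default CUDA generator;
    each named state is stored as a snapshot from random.getstate().
    """

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already registered")
        saved = random.getstate()           # preserve the caller's default state
        random.seed(seed)                   # build the named state from its seed
        self.states[name] = random.getstate()
        random.setstate(saved)              # registration leaves the default untouched

    @contextlib.contextmanager
    def fork(self, name="model-parallel-rng"):
        if name not in self.states:
            raise KeyError(f"RNG state {name!r} is not registered")
        saved = random.getstate()           # 1. save the current default state
        random.setstate(self.states[name])  # 2. switch to the named state
        try:
            yield                           # 3. run the enclosed operations
        finally:
            self.states[name] = random.getstate()  # 4. keep the advanced named state
            random.setstate(saved)                 # 5. restore the default state


tracker = RNGStatesTracker()
tp_rank = 1  # hypothetical rank
tracker.add("model-parallel-rng", 42 + 2718 + tp_rank)

random.seed(0)
before = random.getstate()
with tracker.fork():
    first = random.random()
with tracker.fork():
    second = random.random()   # continues the named sequence
after = random.getstate()      # default state is unchanged by both forks
```

Step 4 is what lets the second fork continue the named sequence; without it, every fork would replay the same draws and every dropout call would reuse one mask.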
Usage
Use when initializing TE models for tensor-parallel or FSDP training. TransformerLayer's get_rng_state_tracker parameter expects a callable that returns a CudaRNGStatesTracker instance.
Typical setup pattern:
- Create a tracker instance.
- Register the "model-parallel-rng" state with a rank-dependent seed.
- Pass a getter function to TE modules so they can fork the RNG state during the forward pass.
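The three setup steps can be sketched without a GPU. In TE the tracker would be a CudaRNGStatesTracker and the getter would be passed to TransformerLayer as get_rng_state_tracker. This runnable analogue swaps in a stdlib stand-in tracker. The names _TRACKER, ToyDropout, and the seed formula are hypothetical:

```python
import contextlib
import random

class RNGStatesTracker:
    """Minimal CPU stand-in for a CUDA RNG states tracker (illustrative)."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        self.states[name] = random.Random(seed).getstate()

    @contextlib.contextmanager
    def fork(self, name="model-parallel-rng"):
        saved = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            self.states[name] = random.getstate()
            random.setstate(saved)

# Step 1: create one tracker instance per process.
_TRACKER = RNGStatesTracker()

def get_rng_state_tracker():
    """Getter handed to modules instead of the tracker object itself."""
    return _TRACKER

# Step 2: register the rank-dependent state (hypothetical seed formula).
base_seed, tp_rank = 42, 0
_TRACKER.add("model-parallel-rng", base_seed + 2718 + tp_rank)

class ToyDropout:
    """Hypothetical module that forks the tracker around its random draws."""

    def __init__(self, p, get_rng_state_tracker):
        self.p = p
        self.get_rng_state_tracker = get_rng_state_tracker

    def forward(self, xs):
        # Draw the mask from the per-rank state, then fall back to the default.
        with self.get_rng_state_tracker().fork():
            return [0.0 if random.random() < self.p else x / (1 - self.p) for x in xs]

# Step 3: pass the getter to the module.
layer = ToyDropout(p=0.1, get_rng_state_tracker=get_rng_state_tracker)
out = layer.forward([1.0] * 8)
```

Passing a getter rather than the object lets every module resolve the same per-process tracker when it actually runs.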
This is essential for:
- Tensor-parallel training where dropout must be uncorrelated across ranks.
- Activation checkpointing where the forward pass is replayed during backward, requiring exact RNG state reproduction.
- Deterministic training where full reproducibility across runs is required.
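The activation-checkpointing requirement comes down to snapshot and rewind. In this stdlib sketch the module-level random generator stands in for a single CUDA RNG state. It is an analogue, not TE's checkpoint implementation, and a real checkpoint would need to snapshot the default CUDA state and every tracker state. The functions dropout_mask and checkpointed_dropout are hypothetical:

```python
import random

def dropout_mask(n, p=0.5):
    # Draws from the module-level generator (stand-in for one CUDA RNG state).
    return [0 if random.random() < p else 1 for _ in range(n)]

def checkpointed_dropout(n):
    # Forward pass: snapshot the RNG state, then run without keeping activations.
    saved = random.getstate()
    forward_mask = dropout_mask(n)
    after_forward = random.getstate()

    # Backward pass: rewind to the snapshot so the recomputation sees the same draws.
    random.setstate(saved)
    recomputed_mask = dropout_mask(n)

    # Resume from where the original forward pass left the generator.
    random.setstate(after_forward)
    return forward_mask, recomputed_mask

random.seed(0)
fwd, recomputed = checkpointed_dropout(32)
state_after = random.getstate()

# Without the rewind, the "recomputation" draws fresh numbers and the masks differ.
random.seed(0)
fwd_again = dropout_mask(32)
no_rewind = dropout_mask(32)
```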