Heuristic: Hpcaitech ColossalAI Empty Cache Between Phases
| Knowledge Sources | |
|---|---|
| Domains | Memory_Management, Distributed_Training |
| Last Updated | 2026-02-09 03:00 GMT |
Overview
Call `torch.cuda.empty_cache()` between training phases, before checkpointing, and between producer/consumer transitions to prevent OOM from CUDA memory fragmentation.
Description
PyTorch's CUDA caching allocator keeps freed GPU memory blocks for reuse, but these cached blocks can cause fragmentation when allocation patterns change between training phases. In ColossalAI's RLHF pipeline (GRPO), the producer (rollout generation) and consumer (policy training) phases have very different memory profiles. Explicitly clearing the CUDA cache at phase boundaries releases unused cached blocks back to the CUDA driver, allowing the next phase to allocate contiguous memory. The same principle applies before saving checkpoints, where serialization requires temporary memory for copies of the state dict.
Usage
Insert `torch.cuda.empty_cache()` (or `accelerator.empty_cache()`) at these key points:
- Before and after model state synchronization between producer and consumer
- Before saving model checkpoints
- After deleting large temporary tensors (e.g., `state_dict`)
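The three points above combine into a single sync pattern: clear the cache before the large temporary allocation, and again after deleting it. A minimal sketch of that pattern follows; `Trainer` and `broadcast` here are hypothetical stand-ins, not ColossalAI APIs.

```python
import torch

def sync_model(trainer, broadcast) -> None:
    """Sketch of a model-sync step following the heuristic:
    empty the cache before the big temporary allocation, and
    again after deleting it. `trainer` and `broadcast` are
    illustrative placeholders."""
    torch.cuda.empty_cache()           # release cached blocks before allocating
    state_dict = trainer.state_dict()  # temporary full copy of the weights
    broadcast(state_dict)              # hand weights to the other role
    del state_dict                     # drop the temporary copy...
    torch.cuda.empty_cache()           # ...and return its blocks to the driver
```

Note that `torch.cuda.empty_cache()` is a safe no-op when CUDA has not been initialized, so this pattern does not need a GPU guard.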
The Insight (Rule of Thumb)
- Action: Call `torch.cuda.empty_cache()` at every phase transition boundary.
- Pattern: Pair with `del` on large tensors immediately before the cache clear.
- Frequency: At least once before each model sync, once after each model sync, and once before each checkpoint save.
- Cost: Negligible time overhead (~1-5ms) compared to the cost of an OOM crash and restart.
Reasoning
PyTorch's CUDA caching allocator holds freed memory in pools segmented by block size. When a training phase ends and a new phase begins with different tensor shapes (e.g., switching from variable-length rollout generation to fixed-batch policy gradient updates), the cached blocks may not match the new allocation requests. This causes PyTorch to request new memory from CUDA while old blocks remain cached, eventually exhausting GPU memory. Calling `empty_cache()` returns all unoccupied cached blocks to the CUDA driver, so the next phase can draw from a larger contiguous pool; it cannot release blocks still partially in use, which is why pairing it with `del` on large temporaries matters.
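The gap described above is directly observable: `torch.cuda.memory_allocated()` reports memory live tensors occupy, while `torch.cuda.memory_reserved()` reports what the caching allocator is holding. A sketch for inspecting that gap around a phase boundary (runs as a no-op on CPU-only machines):

```python
import torch

def cache_report(tag: str) -> None:
    """Print the gap between memory live tensors use (allocated)
    and memory the caching allocator holds (reserved). A large
    gap is the cached-but-free memory that empty_cache() releases."""
    if not torch.cuda.is_available():  # skip gracefully without a GPU
        return
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    print(f"[{tag}] allocated={allocated >> 20} MiB, "
          f"reserved={reserved >> 20} MiB, "
          f"cached-but-free={(reserved - allocated) >> 20} MiB")

# Example: bracket a phase boundary with reports
cache_report("before empty_cache")
torch.cuda.empty_cache()  # return unused cached blocks to the driver
cache_report("after empty_cache")
```

On a GPU, the "reserved" figure typically drops after `empty_cache()` while "allocated" stays unchanged, confirming that only unoccupied cached blocks were released.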
Code Evidence
From `applications/ColossalChat/coati/distributed/consumer.py:149-167` (consumer loop start):
    def loop(self) -> None:
        self.profiler.enter("sync_model")
        torch.cuda.empty_cache()
        state_dict = self.state_dict()
        ...
        del state_dict
        torch.cuda.empty_cache()
        self.profiler.exit("sync_model")
From `applications/ColossalChat/coati/distributed/consumer.py:337-353` (mid-episode sync):
    self.profiler.enter("sync_model")
    torch.cuda.empty_cache()
    state_dict = self.state_dict()
    ...
    del state_dict
    torch.cuda.empty_cache()
    self.profiler.exit("sync_model")
From `applications/ColossalChat/coati/distributed/producer.py:220-240` (producer loop start):
    torch.cuda.empty_cache()
    self.profiler.enter("sync_model")
    ...
    self.load_state_dict(state_dict)
    self.profiler.exit("sync_model")
    del state_dict
    torch.cuda.empty_cache()
From `applications/ColossalChat/coati/distributed/producer.py:376-401` (mid-episode sync):
    torch.cuda.empty_cache()
    self.profiler.enter("sync_model")
    ...
    self.profiler.exit("sync_model")
    del state_dict
    torch.cuda.empty_cache()
From `applications/Colossal-LLaMA/train.py:323` (before checkpoint save):
    accelerator.empty_cache()
    save_checkpoint(...)
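In plain PyTorch, without ColossalAI's `accelerator` wrapper, the same pre-checkpoint guard can be sketched as follows; `save_with_headroom` is an illustrative helper, not a ColossalAI function.

```python
import torch

def save_with_headroom(model: torch.nn.Module, path: str) -> None:
    """Sketch of the pre-checkpoint guard: free cached blocks first,
    since serialization needs temporary memory for state-dict copies."""
    torch.cuda.empty_cache()               # reclaim fragmented blocks first
    torch.save(model.state_dict(), path)   # then serialize with more headroom
```

As with the sync pattern, `empty_cache()` is a no-op when CUDA is uninitialized, so the helper works unchanged in CPU-only test runs.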