Heuristic: Kronos Gradient Clipping Strategy
| Knowledge Sources | |
|---|---|
| Domains | Training, Optimization |
| Last Updated | 2026-02-09 13:47 GMT |
Overview
Differentiated gradient clipping strategy: max_norm=2.0 for tokenizer training and max_norm=3.0 for predictor training, with gradient accumulation support.
Description
Kronos applies different gradient clipping thresholds to the tokenizer and predictor during finetuning. The tokenizer uses a tighter clip (max_norm=2.0) because its BSQ quantization loss can produce sharp gradient spikes. The predictor uses a looser clip (max_norm=3.0) because cross-entropy loss on token classification typically has larger but more stable gradients. Both use `torch.nn.utils.clip_grad_norm_` which clips the total gradient norm across all parameters. The tokenizer also supports gradient accumulation to simulate larger batch sizes on memory-limited hardware.
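To make the mechanism concrete, here is a minimal pure-Python sketch of the global-norm clipping that `torch.nn.utils.clip_grad_norm_` performs (the function name and example values below are illustrative, not from the Kronos codebase):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Global-norm clipping: compute the total L2 norm across ALL
    gradient components, and if it exceeds max_norm, rescale every
    component by max_norm / total_norm (direction is preserved)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# A tokenizer-style spike: a gradient of norm 5 is rescaled down to norm 2,
# while a small gradient of norm < 2 would pass through untouched.
clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=2.0)
```

Note that the norm is computed jointly over all parameters, so a spike in one layer (e.g. the BSQ codebook) shrinks the update for every layer in that step.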
Usage
Use this heuristic when:
- Training is diverging: if the loss explodes, try reducing `max_norm` for the affected component
- Training converges too slowly: if the gradient norm is being clipped on nearly every step, try increasing `max_norm`
- Using gradient accumulation: Loss is scaled by `1/accumulation_steps` before backward; gradient clipping happens after all accumulation steps
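The third bullet can be verified with a small pure-Python sketch: scaling each sub-batch loss by `1/accumulation_steps` before backward makes the accumulated gradient equal the full-batch gradient (the toy MSE model and values below are illustrative):

```python
def grad_mse(w, xs, ys):
    """Analytic gradient of mean squared error for y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Gradient of the mean loss over the full batch in one pass.
full = grad_mse(w, xs, ys)

# Same data split into 2 accumulation steps of 2 samples each.
accumulation_steps = 2
accum = 0.0
for i in range(accumulation_steps):
    sub_x, sub_y = xs[2 * i:2 * i + 2], ys[2 * i:2 * i + 2]
    # backward() on the scaled loss ADDS this into .grad; clipping and
    # optimizer.step() only happen once, after the final sub-batch.
    accum += grad_mse(w, sub_x, sub_y) / accumulation_steps
```

Because `accum == full`, clipping after accumulation sees the same gradient it would see with the larger physical batch.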
The Insight (Rule of Thumb)
- Action: Apply `clip_grad_norm_` with different thresholds per model component.
- Value:
- Tokenizer: `max_norm=2.0`
- Predictor: `max_norm=3.0`
- Gradient accumulation: Default `accumulation_steps=1`. When >1, loss is divided by accumulation_steps before backward pass. Optimizer step happens after all sub-batches.
- Trade-off: Tighter clipping (lower max_norm) prevents gradient explosion but may slow convergence. The tokenizer needs tighter control due to its quantization loss dynamics.
Reasoning
The tokenizer combines MSE reconstruction loss with BSQ quantization loss (commit loss + entropy penalty). The BSQ loss operates on discrete codebook assignments and can produce sporadic large gradients when codebook entries shift. The max_norm=2.0 prevents these spikes from destabilizing training.
The predictor uses cross-entropy loss on token classification (dual-head: S1 + S2). Cross-entropy is comparatively well behaved: the loss starts near `log(vocab_size)` for a uniform prediction, and its gradient with respect to each logit is `softmax(z) - one_hot(target)`, bounded in magnitude by 1. This predictability is why the looser max_norm=3.0 can safely allow the predictor larger gradient steps when needed.
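The boundedness claim can be checked numerically. A short sketch (illustrative logits, not from the codebase) of the cross-entropy gradient with respect to the logits:

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

# Gradient of cross-entropy w.r.t. logits is softmax(z) - one_hot(target).
# Each component lies in [-1, 1] no matter how extreme the logits are.
z = [10.0, -3.0, 0.5, 7.0]
target = 1
p = softmax(z)
grad = [pi - (1.0 if i == target else 0.0) for i, pi in enumerate(p)]
```

Contrast this with the BSQ quantization loss, whose gradients can spike when codebook assignments flip, motivating the tokenizer's tighter clip.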
Gradient accumulation enables effective batch size scaling without increasing memory usage: `effective_batch = batch_size * world_size * accumulation_steps`.
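A quick numeric example of the formula (the specific values are hypothetical; the expression mirrors the logging in `train_tokenizer.py`):

```python
batch_size = 32          # per-GPU batch size
world_size = 4           # number of DDP ranks
accumulation_steps = 2   # sub-batches per optimizer step

# Samples that contribute to each optimizer step: 32 * 4 * 2 = 256,
# while peak memory is still governed by the per-GPU batch of 32.
effective_bs = batch_size * world_size * accumulation_steps
```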
Evidence from `finetune/train_tokenizer.py:146-151`:
```python
loss_scaled = loss / config['accumulation_steps']
current_batch_total_loss += loss.item()
loss_scaled.backward()

# --- Optimizer Step after Accumulation ---
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
```
Predictor gradient clipping from `finetune/train_predictor.py:112-114`:
```python
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
```
Gradient accumulation config from `finetune/config.py:61-62`:
```python
# Gradient accumulation to simulate a larger batch size.
self.accumulation_steps = 1
```
Effective batch size logging from `finetune/train_tokenizer.py:92-94`:
```python
effective_bs = config['batch_size'] * world_size * config['accumulation_steps']
print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}")
print(f"[Rank {rank}] Effective total batch size: {effective_bs}")
```