

Heuristic:Shiyu coder Kronos Gradient Clipping Strategy

From Leeroopedia




Knowledge Sources
Domains: Training, Optimization
Last Updated: 2026-02-09 13:47 GMT

Overview

Differentiated gradient clipping strategy: max_norm=2.0 for tokenizer training and max_norm=3.0 for predictor training, with gradient accumulation support in the tokenizer training loop.

Description

Kronos applies different gradient clipping thresholds to the tokenizer and the predictor during finetuning. The tokenizer uses a tighter clip (max_norm=2.0) because its BSQ quantization loss can produce sharp gradient spikes. The predictor uses a looser clip (max_norm=3.0) because cross-entropy loss on token classification typically produces larger but more stable gradients. Both use `torch.nn.utils.clip_grad_norm_`, which computes a single L2 norm over the gradients of all parameters and, when that norm exceeds max_norm, rescales every gradient by the same factor to bring the total norm down to max_norm. The tokenizer training loop also supports gradient accumulation to simulate larger batch sizes on memory-limited hardware.
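To make "clips the total gradient norm across all parameters" concrete, here is a minimal pure-Python sketch of global-norm clipping. It assumes gradients are plain lists of floats (one list per parameter tensor); the names `clip_total_norm` and `MAX_NORM` are hypothetical and not part of Kronos or PyTorch. It follows the behaviour of `clip_grad_norm_` with the default L2 norm: one norm over everything, one shared scale factor `min(1, max_norm / (total_norm + 1e-6))`, and the pre-clip norm returned.

```python
import math
from typing import List

# Per-component thresholds quoted on this page (the dict itself is hypothetical).
MAX_NORM = {"tokenizer": 2.0, "predictor": 3.0}


def clip_total_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Rescale all gradients in place so their combined L2 norm is at most max_norm.

    The norm is taken over every parameter as if they were one long vector,
    so the direction of the full gradient is preserved. Returns the norm
    measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Shared scale factor, capped at 1.0 so small gradients are left unchanged.
    clip_coef = min(1.0, max_norm / (total_norm + eps))
    for grad in grads:
        for i in range(len(grad)):
            grad[i] *= clip_coef
    return total_norm
```

For example, gradients `[[3.0, 0.0], [0.0, 4.0]]` have total norm 5.0; clipping with `MAX_NORM["tokenizer"]` scales both lists by roughly 0.4, giving a combined norm of about 2.0, whereas gradients whose norm is already below the threshold are left untouched.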

Usage

Use this heuristic when:

  • Training is diverging: if the loss explodes, try reducing max_norm
  • Training is too slow to converge: if gradients are clipped on nearly every step, try increasing max_norm
  • Using gradient accumulation: the loss is scaled by `1/accumulation_steps` before each backward pass, and gradient clipping happens once, after all accumulation steps
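Deciding between "reduce" and "increase" is easier with numbers. `clip_grad_norm_` returns the total norm measured before clipping, so one option is to record it every step (in PyTorch, `total_norm.item()`) and check how often it exceeds the threshold. The helper below is a hypothetical sketch, and the logged norms in the usage note are made-up illustrative values.

```python
from typing import Iterable


def clipped_fraction(pre_clip_norms: Iterable[float], max_norm: float) -> float:
    """Fraction of optimizer steps whose pre-clip gradient norm exceeded max_norm.

    A value close to 1.0 means clipping is active almost every step (a sign
    max_norm may be too tight); occasional clipping means it is only
    catching spikes.
    """
    norms = list(pre_clip_norms)
    if not norms:
        return 0.0
    return sum(1 for n in norms if n > max_norm) / len(norms)
```

With illustrative logged norms `[1.2, 2.5, 1.8, 4.0]` and the tokenizer threshold of 2.0, `clipped_fraction` returns 0.5: half of those steps were clipped.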

The Insight (Rule of Thumb)

  • Action: Apply `clip_grad_norm_` with different thresholds per model component.
  • Value:
    • Tokenizer: `max_norm=2.0`
    • Predictor: `max_norm=3.0`
  • Gradient accumulation: Default `accumulation_steps=1`. When it is greater than 1, the loss is divided by accumulation_steps before each backward pass, and the optimizer step happens after all sub-batches.
  • Trade-off: Tighter clipping (lower max_norm) prevents gradient explosion but may slow convergence. The tokenizer needs tighter control due to its quantization loss dynamics.
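The accumulation rule relies on one identity: for a mean-reduced loss over equal-sized sub-batches, summing the gradients of `loss / accumulation_steps` gives the same gradient as one pass over the full batch. The sketch below checks this on a one-parameter linear model with a hand-written MSE gradient, then clips once after accumulation and applies plain SGD. The model, the data, and the function names are all hypothetical stand-ins, not the Kronos training loop.

```python
from typing import Sequence, Tuple

Batch = Sequence[Tuple[float, float]]


def mse_grad(w: float, batch: Batch) -> float:
    """Gradient d/dw of the mean squared error mean((w*x - y)**2) over a batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: Batch, accumulation_steps: int) -> float:
    """Accumulate gradients over equal sub-batches of a mean-reduced loss."""
    if len(batch) % accumulation_steps != 0:
        raise ValueError("batch size must be divisible by accumulation_steps")
    size = len(batch) // accumulation_steps
    grad = 0.0
    for k in range(accumulation_steps):
        sub_batch = batch[k * size:(k + 1) * size]
        # loss / accumulation_steps has gradient grad / accumulation_steps;
        # repeated backward() calls would sum these into the same .grad buffer.
        grad += mse_grad(w, sub_batch) / accumulation_steps
    return grad


def train_step(w: float, batch: Batch, accumulation_steps: int,
               lr: float, max_norm: float) -> float:
    """One optimizer step: accumulate, clip once, then apply plain SGD."""
    grad = accumulated_grad(w, batch, accumulation_steps)
    # With a single parameter the total L2 norm is just |grad|.
    norm = abs(grad)
    if norm > max_norm:
        grad *= max_norm / norm
    return w - lr * grad
```

With `batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 7.0)]` and `w = 0.5`, both `mse_grad(w, batch)` and `accumulated_grad(w, batch, 2)` give `-18.0`. `train_step(w, batch, 2, lr=0.1, max_norm=2.0)` clips that gradient to about `-2.0` and moves `w` to about `0.7`. Clipping the summed gradient rather than each sub-batch gradient keeps the threshold meaning the same whatever `accumulation_steps` is.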

Reasoning

The tokenizer combines MSE reconstruction loss with BSQ quantization loss (commit loss + entropy penalty). The BSQ loss depends on discrete code assignments and can produce sporadic large gradients when those assignments change. The threshold max_norm=2.0 keeps these spikes from destabilizing training.

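The predictor uses cross-entropy loss on token classification (dual-head: S1 + S2). The gradient of cross-entropy with respect to the logits is `softmax(z) - onehot(y)`, whose L2 norm never exceeds √2 regardless of vocabulary size, which makes these gradients more predictable. (The loss value itself is unbounded; `log(vocab_size)` is only the loss of a uniform prediction.) The higher max_norm=3.0 allows the predictor to take larger gradient steps when needed.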
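The √2 bound can be checked numerically. This stdlib-only sketch computes the cross-entropy loss and its gradient with respect to the logits for a single target. It shows that a confidently wrong prediction has a loss far above `log(vocab_size)` while the logit gradient still stays within √2. The function names are hypothetical. The bound applies only at the logits: gradients with respect to the network's weights also depend on the layers below, which is why clipping is still applied.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax (subtracts the max logit before exp)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy loss -log p[target]; grows without bound as p[target] -> 0."""
    return -math.log(softmax(logits)[target])


def cross_entropy_logit_grad(logits: List[float], target: int) -> List[float]:
    """Gradient of cross-entropy w.r.t. the logits: softmax(z) - onehot(target)."""
    probs = softmax(logits)
    return [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]


def l2_norm(vec: List[float]) -> float:
    return math.sqrt(sum(v * v for v in vec))
```

With 1024 uniform logits the loss equals `log(1024)`. With logits `[50.0] + [0.0] * 1023` and target 1, the loss is about 50, yet the logit-gradient norm is about 1.414, right at the √2 ceiling.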

Gradient accumulation enables effective batch size scaling without increasing memory usage: `effective_batch = batch_size * world_size * accumulation_steps`.
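As a worked example of this formula, the sketch below computes the effective batch size and also runs it in reverse: it picks the smallest `accumulation_steps` that reaches a target effective batch for a given per-GPU batch size and GPU count. Both helpers and all numbers in the usage note are hypothetical illustrations, not Kronos defaults.

```python
def effective_batch_size(batch_size: int, world_size: int, accumulation_steps: int) -> int:
    """Number of samples contributing to one optimizer step."""
    return batch_size * world_size * accumulation_steps


def accumulation_steps_for(target: int, batch_size: int, world_size: int) -> int:
    """Smallest accumulation_steps whose effective batch is at least target."""
    per_micro_step = batch_size * world_size
    # Integer ceiling division; never go below one step.
    return max(1, -(-target // per_micro_step))
```

For instance, with a per-GPU batch of 32 on 4 GPUs (128 samples per micro-step), a target of 500 needs `accumulation_steps_for(500, 32, 4) == 4`, which gives an effective batch of 512.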

Evidence from `finetune/train_tokenizer.py:146-151`:

loss_scaled = loss / config['accumulation_steps']
current_batch_total_loss += loss.item()
loss_scaled.backward()

# --- Optimizer Step after Accumulation ---
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)

Predictor gradient clipping from `finetune/train_predictor.py:112-114`:

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)

Gradient accumulation config from `finetune/config.py:61-62`:

# Gradient accumulation to simulate a larger batch size.
self.accumulation_steps = 1

Effective batch size logging from `finetune/train_tokenizer.py:92-94`:

effective_bs = config['batch_size'] * world_size * config['accumulation_steps']
print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}")
print(f"[Rank {rank}] Effective total batch size: {effective_bs}")
