Implementation:Shiyu coder Kronos Train Model Tokenizer Qlib
| Field | Value |
|---|---|
| implementation_name | Train_Model_Tokenizer_Qlib |
| type | API Doc |
| repository | https://github.com/shiyu-coder/Kronos |
| source_file | finetune/train_tokenizer.py:L74-215 (train_model function) |
| implements | Principle:Shiyu_coder_Kronos_Tokenizer_Finetuning |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The train_model function implements the main training and validation loop for fine-tuning the VQ-VAE tokenizer using reconstruction loss and BSQ quantization loss with DDP-distributed training.
Function Signature
```python
def train_model(model, device, config, save_dir, logger, rank, world_size) -> (model, dt_result)
```
Also
- `main(config: dict)` at lines 218-281: Orchestrates DDP setup, model initialization, and the training invocation
- `setup_ddp()` in `finetune/utils/training_utils.py:L9-32`: Initializes the distributed process group (sketched below)
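For context, here is a minimal sketch of what a `setup_ddp()` helper of this kind typically does. The exact body and return values in `finetune/utils/training_utils.py` may differ; the environment variables used are the standard ones exported by `torchrun`:
```python
import os
import torch
import torch.distributed as dist

def setup_ddp():
    """Sketch: read the rank/world-size variables exported by torchrun,
    bind this process to its GPU, and join the NCCL process group."""
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank, world_size, local_rank
```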
Import / Invocation
Run as a script via torchrun:
```bash
torchrun --standalone --nproc_per_node=N finetune/train_tokenizer.py
```
Dependencies
- `torch`, `torch.distributed`, `torch.nn.functional`
- `torch.nn.parallel.DistributedDataParallel`
- `torch.utils.data.DataLoader`, `torch.utils.data.distributed.DistributedSampler`
- `comet_ml`
- `model.kronos.KronosTokenizer`
- `dataset.QlibDataset`
Parameters
| Parameter | Type | Description |
|---|---|---|
| `model` | DDP-wrapped model | The `KronosTokenizer` wrapped in `DistributedDataParallel` |
| `device` | `torch.device` | Device for the current process (e.g., `cuda:0`) |
| `config` | `dict` | Configuration dictionary (from `Config.__dict__`) |
| `save_dir` | `str` | Directory for saving checkpoints |
| `logger` | `comet_ml.Experiment` or `None` | Comet logger instance (only on rank 0) |
| `rank` | `int` | Global rank of the current process |
| `world_size` | `int` | Total number of processes |
Input
- Model: `KronosTokenizer` loaded from `Config.pretrained_tokenizer_path`
- Data: `QlibDataset` instances for the train and validation splits (see the loader sketch after the Output list)
Output
- Return value: Tuple `(model, dt_result)`, where `dt_result` is a dict containing `{'best_val_loss': float}`
- Side effect: Fine-tuned tokenizer saved to `{save_dir}/checkpoints/best_model`
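A rough sketch of how the train and validation loaders are typically wired up under DDP. The `QlibDataset` constructor argument and the `batch_size` config key shown here are assumptions, not the repository's exact signatures; `rank`, `world_size`, and `config` are the `train_model` arguments documented above:
```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from dataset import QlibDataset

# Hypothetical construction; the actual QlibDataset arguments may differ.
train_dataset = QlibDataset('train')
val_dataset = QlibDataset('val')

# Each rank sees a disjoint shard of the data.
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], sampler=val_sampler)
```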
Loss Computation
```python
# Forward pass through tokenizer
zs, bsq_loss, _, _ = model(batch_x)
z_pre, z = zs

# Reconstruction losses
recon_loss_pre = F.mse_loss(z_pre, batch_x)  # stage-1 decoder reconstruction
recon_loss_all = F.mse_loss(z, batch_x)      # full decoder reconstruction

# Combined loss
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2
```
Where:
- `z_pre`: Output from the s1-only decoder path
- `z`: Output from the full (s1+s2) decoder path
- `bsq_loss`: Binary Spherical Quantization (BSQ) loss from the tokenizer's codebook
Training Loop Details
Optimizer and Scheduler
```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['tokenizer_learning_rate'],      # 2e-4
    weight_decay=config['adam_weight_decay']   # 0.1
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=config['tokenizer_learning_rate'],
    steps_per_epoch=len(train_loader),
    epochs=config['epochs'],                   # 30
    pct_start=0.03,
    div_factor=10
)
```
Gradient Accumulation
Each batch is split into `accumulation_steps` sub-batches. The loss is scaled by `1 / accumulation_steps` before `backward()`. After all sub-batches, gradients are clipped to a max norm of 2.0 and the optimizer steps.
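A minimal sketch of this pattern, assuming the objects defined above (`model`, `optimizer`, `scheduler`, `train_loader`, `device`) and a config key named `accumulation_steps`; the exact sub-batch splitting and `zero_grad` placement in `train_model` may differ:
```python
import torch
import torch.nn.functional as F

for batch_x in train_loader:
    batch_x = batch_x.to(device)
    optimizer.zero_grad()

    # Split the batch into accumulation_steps sub-batches and accumulate gradients.
    for sub_x in torch.chunk(batch_x, config['accumulation_steps'], dim=0):
        zs, bsq_loss, _, _ = model(sub_x)
        z_pre, z = zs
        recon_loss = F.mse_loss(z_pre, sub_x) + F.mse_loss(z, sub_x)
        loss = (recon_loss + bsq_loss) / 2
        # Scale so the accumulated gradient matches a full-batch update.
        (loss / config['accumulation_steps']).backward()

    # Clip to max norm 2.0, then step the optimizer and the per-step OneCycleLR scheduler.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
    optimizer.step()
    scheduler.step()
```
Stepping the scheduler once per DataLoader batch matches `steps_per_epoch=len(train_loader)` in the OneCycleLR configuration above.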
Distributed Validation
Validation loss is computed per rank and aggregated across ranks via `dist.all_reduce` with `ReduceOp.SUM`:
```python
# Per-rank accumulation
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)

# Cross-rank aggregation
dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM)
avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item()
```
Checkpointing
Only rank 0 saves checkpoints. The best model (lowest validation loss) is saved via:
```python
model.module.save_pretrained(f"{save_dir}/checkpoints/best_model")
```
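Because the checkpoint is written with `save_pretrained`, it can later be reloaded with the matching `from_pretrained` call. This is a sketch assuming `KronosTokenizer` exposes the usual Hugging Face-style save/load pair implied by the call above:
```python
from model.kronos import KronosTokenizer

# Hypothetical reload of the best checkpoint written by train_model;
# the path mirrors the save location shown above.
tokenizer = KronosTokenizer.from_pretrained(f"{save_dir}/checkpoints/best_model")
tokenizer.eval()
```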
main() Function
The `main(config: dict)` function at lines 218-281 orchestrates the full training pipeline (a rough skeleton follows the list):
- Calls `setup_ddp()` to initialize the distributed environment
- Sets random seeds via `set_seed(config['seed'], rank)`
- Creates save directories (rank 0 only)
- Optionally initializes a Comet ML logger (rank 0 only)
- Loads the pretrained `KronosTokenizer` from `config['pretrained_tokenizer_path']`
- Wraps the model in DDP
- Calls `train_model()`
- Saves a JSON summary file (rank 0 only)
- Calls `cleanup_ddp()`
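A rough skeleton of this control flow. It is a sketch of the step ordering only: the helper import path, the `setup_ddp()` return values, and the `save_path` config key are assumptions for illustration, not the repository's exact code:
```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from model.kronos import KronosTokenizer
# Assumed helper location; only setup_ddp's file is documented above.
from utils.training_utils import setup_ddp, cleanup_ddp, set_seed

def main(config: dict):
    # Join the DDP process group; assumed to return (rank, world_size, local_rank).
    rank, world_size, local_rank = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")
    set_seed(config['seed'], rank)

    save_dir = config['save_path']  # hypothetical key; the real config field may differ
    logger = None
    if rank == 0:
        ...  # create save_dir/checkpoints and optionally init a comet_ml.Experiment

    # Load the pretrained tokenizer and wrap it for distributed training.
    model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path']).to(device)
    model = DDP(model, device_ids=[local_rank])

    model, dt_result = train_model(model, device, config, save_dir, logger, rank, world_size)
    if rank == 0:
        ...  # write a JSON summary containing dt_result
    cleanup_ddp()
```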
Source Reference
File: `finetune/train_tokenizer.py`, lines 74-215 (`train_model`), lines 218-281 (`main`).
Environment & Heuristic Links
- Environment:Shiyu_coder_Kronos_PyTorch_CUDA_Environment
- Environment:Shiyu_coder_Kronos_DDP_Multi_GPU_Environment
- Environment:Shiyu_coder_Kronos_Comet_ML_Logging
- Heuristic:Shiyu_coder_Kronos_Two_Stage_Finetuning_Strategy
- Heuristic:Shiyu_coder_Kronos_Learning_Rate_And_Optimizer_Tuning
- Heuristic:Shiyu_coder_Kronos_Gradient_Clipping_Strategy