
Implementation:Shiyu coder Kronos Train Model Tokenizer Qlib

From Leeroopedia


Field Value
implementation_name Train_Model_Tokenizer_Qlib
type API Doc
repository https://github.com/shiyu-coder/Kronos
source_file finetune/train_tokenizer.py:L74-215 (train_model function)
implements Principle:Shiyu_coder_Kronos_Tokenizer_Finetuning
last_updated 2026-02-09 14:00 GMT

Summary

The train_model function implements the main training and validation loop for fine-tuning the VQ-VAE tokenizer. It optimizes the sum of a reconstruction loss and the BSQ quantization loss, with training distributed via DistributedDataParallel (DDP).

Function Signature

def train_model(model, device, config, save_dir, logger, rank, world_size) -> (model, dt_result)

Related Functions

  • main(config: dict) at lines 218-281: Orchestrates DDP setup, model initialization, and training invocation
  • setup_ddp() via finetune/utils/training_utils.py:L9-32: Initializes the distributed process group
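
The repository's setup_ddp helper is not reproduced on this page; a minimal sketch of a typical torchrun-compatible initialization (assuming the standard RANK / WORLD_SIZE / LOCAL_RANK environment variables, which torchrun sets) might look like:

```python
import os
import torch
import torch.distributed as dist

def setup_ddp() -> tuple[int, int, int]:
    """Initialize the default process group from torchrun's env vars.

    Hypothetical sketch; the actual helper lives in
    finetune/utils/training_utils.py:L9-32 and may differ.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # NCCL for multi-GPU jobs; gloo is a CPU-only fallback.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
```
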

Import / Invocation

Run as a script via torchrun:

torchrun --standalone --nproc_per_node=N finetune/train_tokenizer.py

Dependencies

  • torch, torch.distributed, torch.nn.functional
  • torch.nn.parallel.DistributedDataParallel
  • torch.utils.data.DataLoader, torch.utils.data.distributed.DistributedSampler
  • comet_ml
  • model.kronos.KronosTokenizer
  • dataset.QlibDataset

Parameters

Parameter Type Description
model DDP-wrapped model The KronosTokenizer wrapped in DistributedDataParallel
device torch.device Device for the current process (e.g., cuda:0)
config dict Configuration dictionary (from Config.__dict__)
save_dir str Directory for saving checkpoints
logger comet_ml.Experiment or None Comet logger instance (only on rank 0)
rank int Global rank of the current process
world_size int Total number of processes

Input

  • Model: KronosTokenizer loaded from Config.pretrained_tokenizer_path
  • Data: QlibDataset for train and validation splits

Output

  • Return value: Tuple of (model, dt_result) where dt_result is a dict containing {'best_val_loss': float}
  • Side effect: Fine-tuned tokenizer saved to {save_dir}/checkpoints/best_model

Loss Computation

# Forward pass through tokenizer
zs, bsq_loss, _, _ = model(batch_x)
z_pre, z = zs

# Reconstruction losses
recon_loss_pre = F.mse_loss(z_pre, batch_x)   # Stage-1 decoder reconstruction
recon_loss_all = F.mse_loss(z, batch_x)        # Full decoder reconstruction

# Combined loss
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2

Where:

  • z_pre: Output from the s1-only decoder path
  • z: Output from the full (s1+s2) decoder path
  • bsq_loss: Binary Spherical Quantization (BSQ) loss from the tokenizer's quantizer
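
The combined objective can be reproduced in isolation. The sketch below substitutes random tensors for the tokenizer's outputs (batch_x, z_pre, z, and the bsq_loss value are stand-ins, not the real model's tensors):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for a (batch, seq_len, features) market-data window and the
# tokenizer's two decoder reconstructions.
batch_x = torch.randn(8, 32, 5)
z_pre = batch_x + 0.1 * torch.randn_like(batch_x)  # s1-only reconstruction
z = batch_x + 0.05 * torch.randn_like(batch_x)     # full s1+s2 reconstruction
bsq_loss = torch.tensor(0.02)                      # placeholder quantization loss

recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2
```
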

Training Loop Details

Optimizer and Scheduler

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['tokenizer_learning_rate'],   # 2e-4
    weight_decay=config['adam_weight_decay']  # 0.1
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=config['tokenizer_learning_rate'],
    steps_per_epoch=len(train_loader),
    epochs=config['epochs'],    # 30
    pct_start=0.03,
    div_factor=10
)
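
A self-contained version of this setup (with a toy model and a hypothetical loader length standing in for len(train_loader)) illustrates OneCycleLR's warm-up behavior: the learning rate starts at max_lr / div_factor, i.e. 2e-5 here, and ramps up over the first pct_start fraction of steps.

```python
import torch

model = torch.nn.Linear(4, 4)  # toy stand-in for the tokenizer
config = {"tokenizer_learning_rate": 2e-4, "adam_weight_decay": 0.1, "epochs": 30}
steps_per_epoch = 100  # hypothetical len(train_loader)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["tokenizer_learning_rate"],
    weight_decay=config["adam_weight_decay"],
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=config["tokenizer_learning_rate"],
    steps_per_epoch=steps_per_epoch,
    epochs=config["epochs"],
    pct_start=0.03,   # 3% of total steps spent warming up
    div_factor=10,    # initial_lr = max_lr / 10 = 2e-5
)
initial_lr = optimizer.param_groups[0]["lr"]
```
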

Gradient Accumulation

Each batch is split into accumulation_steps sub-batches. The loss of each sub-batch is scaled by 1 / accumulation_steps before backward(). After all sub-batches have been processed, gradients are clipped to a max norm of 2.0 and the optimizer steps.
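
That sub-batch loop can be sketched as follows; the model, data, and accumulation_steps value are placeholders, while the clip norm of 2.0 matches the description above:

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
accumulation_steps = 4

optimizer.zero_grad()
for _ in range(accumulation_steps):
    sub_batch = torch.randn(16, 8)            # placeholder sub-batch
    loss = (model(sub_batch) - sub_batch).pow(2).mean()
    (loss / accumulation_steps).backward()    # scale so gradients average
# Clip once over the accumulated gradients, then take one optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
optimizer.step()
```
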

Distributed Validation

Validation loss is computed per rank and aggregated via dist.all_reduce(SUM):

# Per-rank accumulation
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)

# Cross-rank aggregation
dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM)
avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item()
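
Numerically, this all_reduce aggregation computes a sample-weighted mean across ranks. The equivalent single-process computation, with made-up per-rank (loss_sum, sample_count) pairs, is:

```python
# Hypothetical per-rank (loss_sum, sample_count) pairs, as each rank
# would hold them just before dist.all_reduce.
per_rank = [(12.0, 40), (9.0, 30), (6.5, 26), (8.1, 32)]

# all_reduce(SUM) on both tensors is equivalent to summing across ranks.
val_loss_sum = sum(s for s, _ in per_rank)
val_count = sum(c for _, c in per_rank)
avg_val_loss = val_loss_sum / val_count
```

This weighting matters when ranks see different numbers of validation samples: averaging per-rank means directly would bias the result toward smaller shards.
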

Checkpointing

Only rank 0 saves checkpoints. The best model (lowest validation loss) is saved via:

model.module.save_pretrained(f"{save_dir}/checkpoints/best_model")

main() Function

The main(config: dict) function at lines 218-281 orchestrates the full training pipeline:

  1. Calls setup_ddp() to initialize distributed environment
  2. Sets random seeds via set_seed(config['seed'], rank)
  3. Creates save directories (rank 0 only)
  4. Optionally initializes Comet ML logger (rank 0 only)
  5. Loads pretrained KronosTokenizer from config['pretrained_tokenizer_path']
  6. Wraps model in DDP
  7. Calls train_model()
  8. Saves a JSON summary file (rank 0 only)
  9. Calls cleanup_ddp()
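
Steps 3 and 8 (rank-0-only directory creation and the JSON summary) can be sketched as below. The function name, file name summary.json, and summary keys are illustrative assumptions, not the repository's actual names:

```python
import json
import os

def save_rank0_artifacts(save_dir: str, rank: int, best_val_loss: float) -> None:
    """Create checkpoint dirs and write a JSON summary, on rank 0 only.

    Hypothetical helper mirroring steps 3 and 8 of main().
    """
    if rank != 0:
        return
    os.makedirs(os.path.join(save_dir, "checkpoints"), exist_ok=True)
    summary = {"best_val_loss": best_val_loss}  # mirrors dt_result
    with open(os.path.join(save_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
```

Guarding on rank avoids concurrent writes: every process runs main(), but only one should touch the filesystem.
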

Source Reference

File: finetune/train_tokenizer.py, lines 74-215 (train_model), lines 218-281 (main).
