
Principle:Shiyu coder Kronos Tokenizer Finetuning

From Leeroopedia


Field Value
principle_name Tokenizer_Finetuning
repository https://github.com/shiyu-coder/Kronos
domains Training, VQ_VAE, Distributed_Training
implemented_by Implementation:Shiyu_coder_Kronos_Train_Model_Tokenizer_Qlib
last_updated 2026-02-09 14:00 GMT

Summary

Fine-tuning the VQ-VAE tokenizer using reconstruction loss and BSQ quantization loss with DDP-distributed training on domain-specific financial data.

Concept

The Tokenizer Finetuning principle describes how a pretrained VQ-VAE (Vector Quantized Variational Autoencoder) tokenizer is adapted to a new financial data domain. The tokenizer converts continuous time series into discrete token sequences, which the autoregressive predictor model consumes. Fine-tuning ensures the tokenizer learns an effective codebook for the target domain's data distribution.
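To make the continuous-to-discrete round trip concrete, here is a toy illustration of what a tokenizer does to a time-series window. This is illustrative only: the real Kronos tokenizer uses learned encoder/decoder networks and a BSQ codebook, and the function name here is hypothetical.

```python
import torch

def encode_to_tokens(series, dim=4):
    """Toy quantizer: map a (T, dim) continuous window to discrete token ids
    by sign-thresholding each feature and packing the bits into an integer."""
    codes = (series > 0).to(torch.long)                      # binary code per feature
    token_ids = (codes * (2 ** torch.arange(dim))).sum(-1)   # pack bits -> one id per step
    return token_ids

series = torch.randn(8, 4)         # 8 time steps, 4 features
tokens = encode_to_tokens(series)  # 8 discrete token ids in [0, 2**4 - 1]
```

The autoregressive predictor then consumes such discrete token sequences; fine-tuning adapts the learned codebook so these tokens represent the target domain's distribution well.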

Theory

Two-Component Loss

The tokenizer training objective combines two loss components:

  • Reconstruction loss: Measures how well the tokenizer can reconstruct the original input from its discrete representation. This uses MSE (Mean Squared Error) between the decoded output and the original input. Two reconstruction signals are computed:
    • z_pre: Output from the stage-1 (s1-only) decoder path
    • z: Output from the full decoder path (both s1 and s2)
    • Total reconstruction loss: recon_loss = MSE(z_pre, input) + MSE(z, input)
  • BSQ quantization loss: The Binary Spherical Quantization loss, which encourages high codebook utilization and keeps encoder outputs close to codebook entries. It is computed internally by the tokenizer model.
  • Combined loss: (recon_loss + bsq_loss) / 2
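The two-component objective above can be sketched as follows (the function name and signature are illustrative, not the repository's actual API; `bsq_loss` is assumed to be returned by the tokenizer's quantization layer):

```python
import torch
import torch.nn.functional as F

def tokenizer_loss(z_pre, z, x, bsq_loss):
    """Combine both reconstruction signals with the BSQ quantization loss.

    z_pre    : reconstruction from the stage-1 (s1-only) decoder path
    z        : reconstruction from the full decoder path (s1 and s2)
    x        : original continuous input window
    bsq_loss : quantization loss computed inside the tokenizer
    """
    recon_loss = F.mse_loss(z_pre, x) + F.mse_loss(z, x)
    combined = (recon_loss + bsq_loss) / 2
    return combined, recon_loss
```

Note that both decoder paths are supervised against the same input, so the stage-1 path cannot degenerate while the full path compensates.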

Optimization Strategy

  • Optimizer: AdamW with configurable weight decay (default 0.1)
  • Scheduler: OneCycleLR with warmup phase (3% of total steps), providing a learning rate that ramps up then decays
  • Gradient accumulation: Splits each batch into sub-batches whose gradients are accumulated before a single optimizer step, yielding a larger effective batch size without increasing peak memory
  • Gradient clipping: Max norm of 2.0 to prevent exploding gradients
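The optimization strategy above can be sketched as a minimal training loop. The model, learning rate, step count, and accumulation factor are stand-ins; the weight decay (0.1), warmup fraction (3%), and clipping norm (2.0) follow the defaults stated above.

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the VQ-VAE tokenizer
total_steps = 100              # illustrative; real runs use epochs * batches
accum_steps = 4                # sub-batches accumulated per optimizer step

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=2e-4, total_steps=total_steps, pct_start=0.03)  # 3% warmup

for step in range(total_steps):
    for sub in range(accum_steps):
        x = torch.randn(16, 8)
        loss = model(x).pow(2).mean() / accum_steps  # scale so grads average
        loss.backward()                              # accumulate across sub-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
    optimizer.step()
    scheduler.step()          # LR ramps up for 3% of steps, then decays
    optimizer.zero_grad()
```

Dividing each sub-batch loss by `accum_steps` keeps the accumulated gradient equal to the gradient of the full batch's mean loss.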

Distributed Training

The training uses PyTorch's Distributed Data Parallel (DDP) for multi-GPU scaling:

  • Each GPU processes a portion of each batch independently
  • Gradients are synchronized across GPUs via all-reduce operations
  • Validation loss is aggregated across all ranks using dist.all_reduce
  • Only the master process (rank 0) performs checkpointing and logging
  • dist.barrier() ensures synchronization between epochs
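The cross-rank aggregation and master-only bookkeeping can be sketched with small helpers (names are illustrative, not the repository's). Guarding on `dist.is_initialized()` lets the same code path run in single-process debugging:

```python
import torch
import torch.distributed as dist

def aggregate_val_loss(local_loss_sum, local_count, device="cpu"):
    """Average validation loss over all DDP ranks via an all-reduce SUM.
    Falls back to the local average when torch.distributed is not initialized."""
    stats = torch.tensor([local_loss_sum, local_count],
                         dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # sum losses and counts
    return (stats[0] / stats[1]).item()

def is_master():
    """Only rank 0 checkpoints and logs."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

# At the end of each epoch, ranks re-synchronize before the next one:
# if dist.is_initialized():
#     dist.barrier()
```

Reducing the loss sum and the sample count separately, rather than per-rank averages, keeps the global average correct even when ranks see unequal numbers of validation samples.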

Validation and Checkpointing

  • After each training epoch, a validation loop computes MSE loss using only the full decoder output (z)
  • The model checkpoint with the lowest validation loss is saved as best_model
  • Training logs are sent to Comet ML for experiment tracking
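A validation loop matching the description above might look like the following sketch. It assumes, for illustration, a model that maps the input directly to the full-decoder reconstruction `z`; the checkpoint path and function names are hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def validate(model, val_loader, device="cpu"):
    """Per-epoch validation: MSE against the full decoder output z only
    (the z_pre signal is used in training but not in validation)."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x in val_loader:
            x = x.to(device)
            z = model(x)  # full decoder path (s1 + s2)
            total += F.mse_loss(z, x, reduction="sum").item()
            count += x.numel()
    return total / count

# Rank-0 checkpointing pattern: keep the lowest-validation-loss weights.
# best_val = math.inf
# val_loss = validate(model, val_loader)
# if val_loss < best_val:
#     best_val = val_loss
#     torch.save(model.state_dict(), "best_model.pt")
```

Because validation scores only the full decoder path, the saved `best_model` reflects end-to-end reconstruction quality rather than the auxiliary stage-1 signal.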

Domains

  • Training: Model fine-tuning with multi-component loss functions
  • VQ_VAE: Vector Quantized Variational Autoencoder architecture
  • Distributed_Training: DDP-based multi-GPU training with gradient synchronization
