Principle: Shiyu-coder Kronos Tokenizer Finetuning
| Field | Value |
|---|---|
| principle_name | Tokenizer_Finetuning |
| repository | https://github.com/shiyu-coder/Kronos |
| domains | Training, VQ_VAE, Distributed_Training |
| implemented_by | Implementation:Shiyu_coder_Kronos_Train_Model_Tokenizer_Qlib |
| last_updated | 2026-02-09 14:00 GMT |
Summary
Fine-tuning the VQ-VAE tokenizer using reconstruction loss and BSQ quantization loss with DDP-distributed training on domain-specific financial data.
Concept
The Tokenizer Finetuning principle describes how a pretrained VQ-VAE (Vector Quantized Variational Autoencoder) tokenizer is adapted to a new financial data domain. The tokenizer converts continuous time series into discrete token sequences, which the autoregressive predictor model consumes. Fine-tuning ensures the tokenizer learns an effective codebook for the target domain's data distribution.
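As an illustration of the tokenization step, the sketch below shows plain nearest-neighbor vector quantization: each timestep's encoder embedding is mapped to the index of its closest codebook entry. This is the generic VQ idea, not Kronos's actual BSQ implementation; the shapes and codebook are made up for the example.

```python
import torch

def quantize(z: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Map continuous embeddings to discrete token ids.

    z: (T, d) encoder outputs; codebook: (K, d) learned codes.
    Returns a (T,) tensor of codebook indices (the token sequence).
    """
    dists = torch.cdist(z, codebook)  # (T, K) pairwise L2 distances
    return dists.argmin(dim=1)        # nearest codebook entry per timestep

torch.manual_seed(0)
codebook = torch.randn(16, 4)  # K=16 codes of dimension d=4 (illustrative)
z = torch.randn(32, 4)         # encoded series of length T=32
tokens = quantize(z, codebook)
print(tokens.shape)            # torch.Size([32])
```

The resulting token ids are what the autoregressive predictor consumes; fine-tuning shifts the codebook so these ids cover the target domain's distribution well.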
Theory
Two-Component Loss
The tokenizer training objective combines two loss components:
- Reconstruction loss: Measures how well the tokenizer can reconstruct the original input from its discrete representation. This uses MSE (Mean Squared Error) between the decoded output and the original input. Two reconstruction signals are computed:
  - z_pre: output from the stage-1 (s1-only) decoder path
  - z: output from the full decoder path (both s1 and s2)
- Total reconstruction loss:
recon_loss = MSE(z_pre, input) + MSE(z, input)
- BSQ quantization loss: The Binary Spherical Quantization loss, which encourages the codebook to be well-utilized and the encoder outputs to lie close to codebook entries. This is computed internally by the tokenizer model.
- Combined loss:
(recon_loss + bsq_loss) / 2
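The combined objective can be sketched as below. Here `z_pre`, `z`, and `bsq_loss` would come from the tokenizer's forward pass; in this standalone example they are stand-in tensors so the arithmetic is runnable.

```python
import torch
import torch.nn.functional as F

def tokenizer_loss(z_pre, z, x, bsq_loss):
    """Two-component objective: dual reconstruction MSE plus BSQ loss,
    averaged as (recon_loss + bsq_loss) / 2."""
    recon_loss = F.mse_loss(z_pre, x) + F.mse_loss(z, x)
    return (recon_loss + bsq_loss) / 2

x = torch.randn(8, 64, 5)                 # batch of input windows (made up shape)
z_pre = x + 0.1 * torch.randn_like(x)     # stand-in stage-1 (s1-only) reconstruction
z = x + 0.05 * torch.randn_like(x)        # stand-in full (s1 + s2) reconstruction
loss = tokenizer_loss(z_pre, z, x, bsq_loss=torch.tensor(0.02))
print(f"{loss.item():.4f}")
```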
Optimization Strategy
- Optimizer: AdamW with configurable weight decay (default 0.1)
- Scheduler: OneCycleLR with warmup phase (3% of total steps), providing a learning rate that ramps up then decays
- Gradient accumulation: Supports splitting each batch into sub-batches for effectively larger batch sizes without increased memory usage
- Gradient clipping: Max norm of 2.0 to prevent exploding gradients
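The optimization recipe above can be sketched in a few lines; the model, learning rate, and step counts here are placeholders, not values from the Kronos configs.

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(5, 5)                       # stand-in for the tokenizer
total_steps = 1000                                  # illustrative step budget
accum_steps = 4                                     # gradient accumulation factor

opt = AdamW(model.parameters(), lr=2e-4, weight_decay=0.1)
sched = OneCycleLR(opt, max_lr=2e-4, total_steps=total_steps,
                   pct_start=0.03)                  # 3% warmup, then decay

# One effective optimizer step built from `accum_steps` sub-batches:
for _ in range(accum_steps):
    x = torch.randn(16, 5)
    loss = model(x).pow(2).mean() / accum_steps     # scale so grads sum to a mean
    loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
opt.step()
sched.step()
opt.zero_grad()
```

Dividing each sub-batch loss by `accum_steps` makes the accumulated gradient equal to that of one large batch, which is what gives larger effective batch sizes without extra memory.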
Distributed Training
The training uses PyTorch's Distributed Data Parallel (DDP) for multi-GPU scaling:
- Each GPU processes a portion of each batch independently
- Gradients are synchronized across GPUs via all-reduce operations
- Validation loss is aggregated across all ranks using dist.all_reduce
- Only the master process (rank 0) performs checkpointing and logging
- dist.barrier() ensures synchronization between epochs
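The rank-aggregation pattern above can be sketched as follows. To keep the example executable without GPUs, it initializes a one-process "gloo" group; in real multi-GPU training each rank would contribute its own shard of the validation loss.

```python
import os
import torch
import torch.distributed as dist

# Single-process process group so the DDP collectives run on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

val_loss = torch.tensor(0.42)               # this rank's local validation loss
dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
val_loss /= dist.get_world_size()           # mean across ranks

if dist.get_rank() == 0:                    # master handles checkpointing/logging
    print(f"val_loss={val_loss.item():.2f}")

dist.barrier()                              # keep ranks in lockstep between epochs
dist.destroy_process_group()
```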
Validation and Checkpointing
- After each training epoch, a validation loop computes MSE loss using only the full decoder output (z)
- The model checkpoint with the lowest validation loss is saved as best_model
- Training logs are sent to Comet ML for experiment tracking
Domains
- Training: Model fine-tuning with multi-component loss functions
- VQ_VAE: Vector Quantized Variational Autoencoder architecture
- Distributed_Training: DDP-based multi-GPU training with gradient synchronization