Implementation:Shiyu_coder_Kronos_Train_Model_Predictor_Qlib
| Field | Value |
|---|---|
| implementation_name | Train_Model_Predictor_Qlib |
| type | API Doc |
| repository | https://github.com/shiyu-coder/Kronos |
| source_file | finetune/train_predictor.py:L60-179 (train_model function) |
| implements | Principle:Shiyu_coder_Kronos_Predictor_Finetuning |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The train_model function implements the main training and validation loop for fine-tuning the autoregressive Transformer predictor using cross-entropy loss on next-token prediction with a frozen tokenizer.
Function Signature
def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size) -> dt_result
Also
main(config: dict) at lines 182-244: orchestrates DDP setup, model loading, and training invocation
Import / Invocation
Run as a script via torchrun:
torchrun --standalone --nproc_per_node=N finetune/train_predictor.py
Dependencies
- torch, torch.distributed
- torch.nn.parallel.DistributedDataParallel
- torch.utils.data.DataLoader, torch.utils.data.distributed.DistributedSampler
- comet_ml
- model.kronos.KronosTokenizer, model.kronos.Kronos
- dataset.QlibDataset
Parameters
| Parameter | Type | Description |
|---|---|---|
| model | DDP-wrapped model | The Kronos predictor wrapped in DistributedDataParallel |
| tokenizer | KronosTokenizer | Frozen fine-tuned tokenizer (eval mode, on device) |
| device | torch.device | Device for the current process |
| config | dict | Configuration dictionary (from Config.__dict__) |
| save_dir | str | Directory for saving checkpoints |
| logger | comet_ml.Experiment or None | Comet logger instance (only on rank 0) |
| rank | int | Global rank of the current process |
| world_size | int | Total number of processes |
Input
- Predictor model: Kronos loaded from Config.pretrained_predictor_path
- Frozen tokenizer: KronosTokenizer loaded from Config.finetuned_tokenizer_path (result of the tokenizer fine-tuning step)
- Data: QlibDataset for the train and validation splits
Output
- Return value: dt_result dict containing {'best_val_loss': float}
- Side effect: fine-tuned predictor saved to {save_dir}/checkpoints/best_model
Training Loop Core
On-the-fly Tokenization
# Tokenize input data with frozen tokenizer
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
# Prepare inputs and targets for next-token prediction
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
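The off-by-one slicing above is standard teacher forcing: the model at position t is trained to predict the token at t+1. A minimal, tokenizer-free sketch of that indexing (toy values standing in for real token IDs):

```python
# Toy token sequence standing in for one row of token_seq_0 (hypothetical IDs).
seq = [10, 11, 12, 13, 14, 15]

# Drop the last token for inputs and the first for targets,
# so token_in[t] is trained to predict token_out[t] == seq[t + 1].
token_in = seq[:-1]   # [10, 11, 12, 13, 14]
token_out = seq[1:]   # [11, 12, 13, 14, 15]

# Every (input, target) pair is one next-token prediction example.
pairs = list(zip(token_in, token_out))
```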
Forward Pass and Loss
# Forward pass through predictor
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
# DualHead cross-entropy loss
loss, s1_loss, s2_loss = model.module.head.compute_loss(
logits[0], logits[1], token_out[0], token_out[1]
)
The DualHead.compute_loss method computes cross-entropy loss for both s1 and s2 logits against their respective target token sequences and returns:
- loss: combined loss
- s1_loss: stage-1 head cross-entropy
- s2_loss: stage-2 head cross-entropy
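Per head, this is ordinary softmax cross-entropy against the shifted targets. A self-contained sketch for a single position per head; the logits are hypothetical, and combining the two terms by simple addition is an assumption here (the actual DualHead may weight them differently):

```python
import math

def cross_entropy(logits, target):
    # Numerically stable -log(softmax(logits)[target]).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

# One position from each head (hypothetical logits over a tiny vocabulary).
s1_loss = cross_entropy([2.0, 0.5, -1.0], target=0)
s2_loss = cross_entropy([0.1, 0.1, 3.0], target=2)
loss = s1_loss + s2_loss  # combined loss; equal weighting is an assumption
```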
Optimization
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['predictor_learning_rate'], # 4e-5
betas=(config['adam_beta1'], config['adam_beta2']), # (0.9, 0.95)
weight_decay=config['adam_weight_decay'] # 0.1
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=config['predictor_learning_rate'],
steps_per_epoch=len(train_loader), epochs=config['epochs'],
pct_start=0.03, div_factor=10
)
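Concretely, pct_start=0.03 means the first 3% of all optimizer steps form the warmup phase, and div_factor=10 means the schedule starts at max_lr / 10. A small arithmetic sketch with hypothetical run sizes (the real steps_per_epoch and epochs come from the dataset and config):

```python
# Hypothetical run sizes; the real values come from len(train_loader) and config.
steps_per_epoch = 500
epochs = 30
max_lr = 4e-5  # predictor_learning_rate

total_steps = steps_per_epoch * epochs     # OneCycleLR's total step budget
warmup_steps = round(0.03 * total_steps)   # pct_start=0.03 -> rising-LR phase
initial_lr = max_lr / 10                   # div_factor=10 -> starting LR
```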
Gradient clipping is applied with max norm 3.0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
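clip_grad_norm_ rescales all gradients by a common factor whenever their global L2 norm exceeds max_norm. A dependency-free sketch of that rule, using a flat list of gradient values in place of parameter tensors:

```python
import math

def clip_by_global_norm(grads, max_norm):
    # Mirrors the rule behind torch.nn.utils.clip_grad_norm_:
    # if ||g|| > max_norm, scale every gradient by max_norm / ||g||.
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# A 3-4-5 triangle of gradients: global norm 5.0 exceeds max_norm 3.0.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=3.0)
```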
Distributed Validation
Validation loss is aggregated across ranks:
val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device)
dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM)
avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item()
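Summing loss totals and batch counts separately before dividing yields the exact global mean over all batches, rather than an average of per-rank averages, which would be biased when ranks process different numbers of batches. A plain-Python sketch of the same reduction for two hypothetical ranks:

```python
# Per-rank partial results (hypothetical values for a 2-process run).
rank_loss_sums = [12.0, 8.0]   # sum of batch losses on each rank
rank_batch_counts = [10, 6]    # validation batches processed on each rank

# dist.all_reduce(..., op=SUM) leaves every rank holding the global totals.
global_loss_sum = sum(rank_loss_sums)
global_batches = sum(rank_batch_counts)
avg_val_loss = global_loss_sum / global_batches

# Averaging per-rank means instead would give a different (biased) number.
mean_of_means = (12.0 / 10 + 8.0 / 6) / 2
```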
main() Function
The main(config: dict) function at lines 182-244:
- Calls setup_ddp() to initialize the distributed environment
- Sets random seeds
- Creates save directories (rank 0 only)
- Optionally initializes the Comet ML logger (rank 0 only)
- Loads the frozen fine-tuned tokenizer from config['finetuned_tokenizer_path'] and sets it to eval()
- Loads the pretrained predictor from config['pretrained_predictor_path']
- Wraps the predictor in DDP
- Calls train_model()
- Saves a JSON summary file (rank 0 only)
- Calls cleanup_ddp()
Key Differences from Tokenizer Training
| Aspect | Tokenizer Training | Predictor Training |
|---|---|---|
| Loss | MSE reconstruction + BSQ | Cross-entropy (DualHead) |
| Learning rate | 2e-4 | 4e-5 |
| Gradient clip norm | 2.0 | 3.0 |
| Gradient accumulation | Yes (configurable) | No |
| Tokenizer role | Being trained | Frozen encoder |
| Time features used | No (only batch_x) | Yes (batch_x_stamp passed to model) |
Source Reference
File: finetune/train_predictor.py, lines 60-179 (train_model), lines 182-244 (main).
Environment & Heuristic Links
- Environment:Shiyu_coder_Kronos_PyTorch_CUDA_Environment
- Environment:Shiyu_coder_Kronos_DDP_Multi_GPU_Environment
- Environment:Shiyu_coder_Kronos_Comet_ML_Logging
- Heuristic:Shiyu_coder_Kronos_Two_Stage_Finetuning_Strategy
- Heuristic:Shiyu_coder_Kronos_Learning_Rate_And_Optimizer_Tuning
- Heuristic:Shiyu_coder_Kronos_Gradient_Clipping_Strategy