
Implementation:Shiyu coder Kronos Train Model Predictor Qlib

From Leeroopedia


Field Value
implementation_name Train_Model_Predictor_Qlib
type API Doc
repository https://github.com/shiyu-coder/Kronos
source_file finetune/train_predictor.py:L60-179 (train_model function)
implements Principle:Shiyu_coder_Kronos_Predictor_Finetuning
last_updated 2026-02-09 14:00 GMT

Summary

The train_model function implements the main training and validation loop for fine-tuning the autoregressive Transformer predictor. Training minimizes a cross-entropy loss on next-token prediction, with the fine-tuned tokenizer kept frozen.

Function Signature

def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size) -> dt_result

Also in this file

  • main(config: dict) at lines 182-244: Orchestrates DDP setup, model loading, and training invocation

Import / Invocation

Run as a script via torchrun:

torchrun --standalone --nproc_per_node=N finetune/train_predictor.py

Dependencies

  • torch, torch.distributed
  • torch.nn.parallel.DistributedDataParallel
  • torch.utils.data.DataLoader, torch.utils.data.distributed.DistributedSampler
  • comet_ml
  • model.kronos.KronosTokenizer, model.kronos.Kronos
  • dataset.QlibDataset

Parameters

Parameter | Type | Description
model | DDP-wrapped model | The Kronos predictor wrapped in DistributedDataParallel
tokenizer | KronosTokenizer | Frozen fine-tuned tokenizer (eval mode, on device)
device | torch.device | Device for the current process
config | dict | Configuration dictionary (from Config.__dict__)
save_dir | str | Directory for saving checkpoints
logger | comet_ml.Experiment or None | Comet logger instance (only on rank 0)
rank | int | Global rank of the current process
world_size | int | Total number of processes
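For orientation, a config dict consistent with the keys referenced on this page might look like the sketch below; the epoch count and both paths are placeholders, not values taken from the repository:

```python
# Sketch of a config dict using keys referenced on this page;
# the epoch count and paths are placeholders, not repository values.
config = {
    'predictor_learning_rate': 4e-5,
    'adam_beta1': 0.9,
    'adam_beta2': 0.95,
    'adam_weight_decay': 0.1,
    'epochs': 30,                                       # placeholder
    'finetuned_tokenizer_path': '/path/to/tokenizer',   # placeholder
    'pretrained_predictor_path': '/path/to/predictor',  # placeholder
}
```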

Input

  • Predictor model: Kronos loaded from Config.pretrained_predictor_path
  • Frozen tokenizer: KronosTokenizer loaded from Config.finetuned_tokenizer_path (result of tokenizer finetuning step)
  • Data: QlibDataset for train and validation splits

Output

  • Return value: dt_result dict containing {'best_val_loss': float}
  • Side effect: Fine-tuned predictor saved to {save_dir}/checkpoints/best_model

Training Loop Core

On-the-fly Tokenization

# Tokenize input data with frozen tokenizer
with torch.no_grad():
    token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

# Prepare inputs and targets for next-token prediction
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
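The slicing above is standard teacher forcing: inputs are the token sequence minus its last element, targets the sequence minus its first, so the target at each position is the next token. A toy illustration with hypothetical token ids:

```python
# Hypothetical token ids produced for one series (teacher-forcing shift):
seq = [5, 2, 9, 7]
token_in = seq[:-1]    # [5, 2, 9]  -> fed to the model
token_out = seq[1:]    # [2, 9, 7]  -> target at position t is seq[t + 1]
```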

Forward Pass and Loss

# Forward pass through predictor
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])

# DualHead cross-entropy loss
loss, s1_loss, s2_loss = model.module.head.compute_loss(
    logits[0], logits[1], token_out[0], token_out[1]
)

The DualHead.compute_loss method computes cross-entropy loss for both s1 and s2 logits against their respective target token sequences and returns:

  • loss: Combined loss
  • s1_loss: Stage-1 head cross-entropy
  • s2_loss: Stage-2 head cross-entropy
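The per-head losses can be illustrated with a toy cross-entropy computation in plain Python; the vocabularies are hypothetical 3-token examples, and the additive combination is an assumption for illustration, not necessarily how DualHead weights the two terms:

```python
import math

def cross_entropy(logits, target):
    # softmax followed by negative log-likelihood, for one position (toy scale)
    exps = [math.exp(l) for l in logits]
    return -math.log(exps[target] / sum(exps))

# Hypothetical logits over 3-token vocabularies for each head:
s1_loss = cross_entropy([2.0, 0.5, 0.1], 0)   # stage-1 head
s2_loss = cross_entropy([0.1, 1.5, 0.3], 1)   # stage-2 head
loss = s1_loss + s2_loss                      # combined (simple sum assumed)
```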

Optimization

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['predictor_learning_rate'],   # 4e-5
    betas=(config['adam_beta1'], config['adam_beta2']),  # (0.9, 0.95)
    weight_decay=config['adam_weight_decay']  # 0.1
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=config['predictor_learning_rate'],
    steps_per_epoch=len(train_loader), epochs=config['epochs'],
    pct_start=0.03, div_factor=10
)
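These scheduler arguments imply a short warmup and a reduced starting rate: the initial learning rate is max_lr / div_factor, and roughly 3% of all optimizer steps are spent warming up. With hypothetical stand-ins for config['epochs'] and len(train_loader), the derived quantities are:

```python
# Derived OneCycleLR quantities; epochs and steps_per_epoch are hypothetical
# stand-ins for config['epochs'] and len(train_loader).
max_lr = 4e-5
div_factor = 10
pct_start = 0.03
epochs = 30
steps_per_epoch = 200

total_steps = epochs * steps_per_epoch          # 6000 optimizer steps
initial_lr = max_lr / div_factor                # LR at step 0: 4e-6
warmup_steps = round(pct_start * total_steps)   # ~3% warmup: 180 steps
```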

Gradient clipping is applied with max norm 3.0:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)

Distributed Validation

Validation loss is aggregated across ranks:

val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device)
dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM)
avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item()
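The two all_reduce calls compute a batch-weighted global mean: per-rank loss sums and batch counts are each summed across processes before dividing, so ranks that processed more batches contribute proportionally more. Simulated with hypothetical per-rank values in plain Python:

```python
# Hypothetical (loss_sum, batch_count) pairs from a 3-rank run; summing both
# quantities before dividing is exactly what the all_reduce calls accomplish.
per_rank = [(12.0, 10), (9.0, 8), (11.5, 10)]
loss_sum = sum(l for l, _ in per_rank)
batch_count = sum(n for _, n in per_rank)
avg_val_loss = loss_sum / batch_count
```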

main() Function

The main(config: dict) function at lines 182-244:

  1. Calls setup_ddp() to initialize distributed environment
  2. Sets random seeds
  3. Creates save directories (rank 0 only)
  4. Optionally initializes Comet ML logger (rank 0 only)
  5. Loads the frozen fine-tuned tokenizer from config['finetuned_tokenizer_path'] and sets to eval()
  6. Loads the pretrained predictor from config['pretrained_predictor_path']
  7. Wraps predictor in DDP
  8. Calls train_model()
  9. Saves a JSON summary file (rank 0 only)
  10. Calls cleanup_ddp()

Key Differences from Tokenizer Training

Aspect | Tokenizer Training | Predictor Training
Loss | MSE reconstruction + BSQ | Cross-entropy (DualHead)
Learning rate | 2e-4 | 4e-5
Gradient clip norm | 2.0 | 3.0
Gradient accumulation | Yes (configurable) | No
Tokenizer role | Being trained | Frozen encoder
Time features used | No (only batch_x) | Yes (batch_x_stamp passed to model)

Source Reference

File: finetune/train_predictor.py, lines 60-179 (train_model), lines 182-244 (main).
