Implementation:Shiyu coder Kronos Generate Predictions Qlib

From Leeroopedia


Field                Value
implementation_name  Generate_Predictions_Qlib
type                 API Doc
repository           https://github.com/shiyu-coder/Kronos
source_file          finetune/qlib_test.py:L239-295
implements           Principle:Shiyu_coder_Kronos_Qlib_Test_Inference
last_updated         2026-02-09 14:00 GMT

Summary

The generate_predictions function runs batch inference over every window of the test dataset with the fine-tuned Kronos tokenizer and predictor, and returns the resulting trading signals as pandas DataFrames for backtesting.

Function Signature

def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]

Related definitions in the same file

  • QlibTestDataset at lines 32-89: PyTorch Dataset for sequential test data iteration
  • load_models(config: dict) -> tuple[KronosTokenizer, Kronos] at lines 207-213: Loads fine-tuned models
  • collate_fn_for_inference(batch) at lines 216-236: Custom collate for mixed tensor/string/timestamp batches

Import

from qlib_test import generate_predictions

Dependencies

  • torch
  • numpy
  • pandas
  • tqdm
  • collections.defaultdict
  • model.kronos.Kronos, model.kronos.KronosTokenizer, model.kronos.auto_regressive_inference

Input

  • config (dict): configuration dict with keys device, tokenizer_path, model_path, max_context, pred_len, clip, T, top_k, top_p, sample_count, batch_size
  • test_data (dict): test data loaded from a pickle, mapping each symbol to a pd.DataFrame (dict[symbol -> pd.DataFrame])
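Because config is a plain dict, a missing key only surfaces as a KeyError wherever it is first read. A minimal pre-flight check could look like the sketch below; check_config and its positivity rule are hypothetical additions, not part of qlib_test.py.

```python
# Keys listed in the Input table for `config`.
REQUIRED_KEYS = (
    "device", "tokenizer_path", "model_path", "max_context", "pred_len",
    "clip", "T", "top_k", "top_p", "sample_count", "batch_size",
)


def check_config(config: dict) -> None:
    """Hypothetical pre-flight check, not part of qlib_test.py.

    Fails fast on a missing key before any model is loaded.
    """
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"config is missing keys: {missing}")
    # Assumed sanity rule: the horizon and the sample count must be positive.
    if config["pred_len"] <= 0 or config["sample_count"] <= 0:
        raise ValueError("pred_len and sample_count must be positive")
```

Calling check_config(run_config) before generate_predictions would surface a typo in a key name immediately.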

Output

A dictionary with keys 'mean', 'last', 'max', 'min', each mapping to a pd.DataFrame with:

  • Index: datetime
  • Columns: symbol names (instrument codes)
  • Values: prediction scores (predicted close minus the last observed close, in the instance-normalized units of each input window)

QlibTestDataset

class QlibTestDataset(Dataset):
    def __init__(self, data: dict, config: Config)
    def __len__(self) -> int
    def __getitem__(self, idx: int) -> tuple

Unlike the training QlibDataset, this dataset:

  • Iterates sequentially through all windows (no random sampling)
  • Returns 5 items per sample: (x, x_stamp, y_stamp, symbol, timestamp)
  • The y_stamp (future time features) is needed for autoregressive inference
  • Window size is lookback_window + predict_window, with no +1: the extra step the training window carries for next-step targets is not needed at test time
  • Instance normalization is applied identically to training
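The sequential windowing can be sketched without PyTorch as follows. The class name, the time-feature column names, the mean/std form of instance normalization with a 1e-5 epsilon followed by clipping, and labelling each window by its last context date are all assumptions for illustration; only the 5-item return shape and the lookback_window + predict_window window size come from the description above.

```python
import numpy as np
import pandas as pd

FEATURES = ["open", "high", "low", "close", "vol", "amt"]
TIME_FEATURES = ["minute", "hour", "weekday", "day", "month"]  # assumption: illustrative names


class SequentialWindows:
    """Hypothetical, framework-free sketch of QlibTestDataset's indexing.

    The real class subclasses torch.utils.data.Dataset and returns tensors;
    numpy arrays are used here so the sketch runs without PyTorch.
    """

    def __init__(self, data: dict, lookback: int, predict: int, clip: float):
        self.data, self.lookback, self.predict, self.clip = data, lookback, predict, clip
        window = lookback + predict  # no +1: no shifted target sequence at test time
        # Enumerate every (symbol, start) window in order: no random sampling.
        self.indices = [
            (sym, start)
            for sym, df in data.items()
            for start in range(len(df) - window + 1)
        ]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int):
        sym, start = self.indices[idx]
        df = self.data[sym]
        ctx = df.iloc[start:start + self.lookback]
        fut = df.iloc[start + self.lookback:start + self.lookback + self.predict]
        x = ctx[FEATURES].to_numpy(dtype=np.float32)
        # Assumed instance normalization: per-feature z-score over the context, then clip.
        x = (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-5)
        x = np.clip(x, -self.clip, self.clip)
        x_stamp = ctx[TIME_FEATURES].to_numpy(dtype=np.float32)
        y_stamp = fut[TIME_FEATURES].to_numpy(dtype=np.float32)  # future time features
        timestamp = ctx.index[-1]  # assumption: window labelled by its last context date
        return x, x_stamp, y_stamp, sym, timestamp
```

With 12 rows, lookback 5 and predict 3, one symbol yields 12 - 8 + 1 = 5 windows, visited in date order.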

collate_fn_for_inference

def collate_fn_for_inference(batch) -> tuple

Custom collate function that handles mixed data types in each batch:

  • Tensors (x, x_stamp, y_stamp): Stacked via torch.stack()
  • Strings (symbols): Collected into a list
  • Timestamps (timestamps): Collected into a list

This is necessary because PyTorch's default collate function (torch.utils.data.default_collate) returns strings as a plain list but raises a TypeError on objects such as pd.Timestamp.
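A minimal sketch of that collate logic follows, with np.stack standing in for torch.stack so it runs without PyTorch. The function name is hypothetical, and the element order follows the 5-item tuple returned by QlibTestDataset.

```python
import numpy as np


def collate_for_inference_sketch(batch):
    """Hypothetical mirror of collate_fn_for_inference using numpy arrays."""
    # Transpose a list of 5-tuples into five parallel tuples.
    x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
    return (
        np.stack(x),        # (B, lookback, n_features)
        np.stack(x_stamp),  # (B, lookback, n_time_features)
        np.stack(y_stamp),  # (B, predict, n_time_features)
        list(symbols),      # B instrument codes
        list(timestamps),   # B timestamp objects, passed through untouched
    )
```

Keeping symbols and timestamps as lists lets the inference loop pair each row of the stacked arrays with its instrument and date by position.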

load_models

def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]

Loads both models from the paths specified in config, moves them to the target device, and sets them to eval mode.

Inference Core

with torch.no_grad():
    for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
        preds = auto_regressive_inference(
            tokenizer, model,
            x.to(device), x_stamp.to(device), y_stamp.to(device),
            max_context=config['max_context'],
            pred_len=config['pred_len'],
            clip=config['clip'],
            T=config['T'],
            top_k=config['top_k'],
            top_p=config['top_p'],
            sample_count=config['sample_count']
        )
        # Keep only the prediction horizon
        preds = preds[:, -config['pred_len']:, :]
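The horizon slice is shape bookkeeping. The sketch below replaces auto_regressive_inference with a hypothetical stub and assumes only that the forecast occupies the last pred_len steps of its output: if the output also carries the context steps, the slice drops them, and if it is already pred_len long, the slice is a no-op.

```python
import numpy as np


def stub_inference(x: np.ndarray, pred_len: int, include_context: bool) -> np.ndarray:
    """Hypothetical stand-in for auto_regressive_inference: repeats the last
    observed step pred_len times, optionally prefixed by the context."""
    forecast = np.repeat(x[:, -1:, :], pred_len, axis=1)  # (B, pred_len, F)
    return np.concatenate([x, forecast], axis=1) if include_context else forecast


x = np.arange(2 * 5 * 6, dtype=np.float32).reshape(2, 5, 6)  # (B=2, lookback=5, F=6)
pred_len = 3
for include_context in (True, False):
    preds = stub_inference(x, pred_len, include_context)
    preds = preds[:, -pred_len:, :]  # same slice as the Inference Core excerpt
    print(preds.shape)
```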

Signal Extraction

Close price is at index 3 in the feature list [open, high, low, close, vol, amt]. Signals are computed as deltas from the last observed close:

last_day_close = x[:, -1, 3].numpy()
signals = {
    'last': preds[:, -1, 3] - last_day_close,
    'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
    'max':  np.max(preds[:, :, 3], axis=1) - last_day_close,
    'min':  np.min(preds[:, :, 3], axis=1) - last_day_close,
}
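The snippet above operates on one batch. The sketch below runs the same arithmetic on a made-up batch of numpy arrays (in the real loop x is a CPU tensor, hence .numpy()), then appends one (timestamp, symbol, score) record per window and signal type into a defaultdict(list), the shape the Post-processing step consumes. The record loop is an assumption about how results is filled; the page itself only describes the resulting tuples.

```python
from collections import defaultdict

import numpy as np
import pandas as pd

CLOSE = 3  # index of 'close' in [open, high, low, close, vol, amt]

# Toy batch: 2 windows, lookback 4, horizon 3, 6 features (values are made up).
x = np.zeros((2, 4, 6), dtype=np.float32)
x[:, -1, CLOSE] = [1.0, -0.5]  # last observed close per window
preds = np.zeros((2, 3, 6), dtype=np.float32)
preds[0, :, CLOSE] = [1.2, 1.4, 1.1]
preds[1, :, CLOSE] = [-0.7, -0.4, -0.6]
symbols = ["SH600000", "SZ000001"]
timestamps = [pd.Timestamp("2024-01-05")] * 2

last_day_close = x[:, -1, CLOSE]
signals = {
    "last": preds[:, -1, CLOSE] - last_day_close,
    "mean": preds[:, :, CLOSE].mean(axis=1) - last_day_close,
    "max": preds[:, :, CLOSE].max(axis=1) - last_day_close,
    "min": preds[:, :, CLOSE].min(axis=1) - last_day_close,
}

# Assumed accumulation: one record per window and signal type.
results = defaultdict(list)
for i, sym in enumerate(symbols):
    for sig_type, values in signals.items():
        results[sig_type].append((timestamps[i], sym, float(values[i])))
```

For the first window, 'max' is 1.4 - 1.0 = 0.4 and 'min' is 1.1 - 1.0 = 0.1; for the second, 'min' is -0.7 - (-0.5) = -0.2.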

Post-processing

Results are collected as (timestamp, symbol, score) tuples for each signal type, then pivoted into one DataFrame per signal type:

for sig_type, records in results.items():
    df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
    pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
    prediction_dfs[sig_type] = pivot_df.sort_index()
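To see what pivot_table produces, here is a toy run with three hand-written records. Each (datetime, instrument) pair becomes one cell, a symbol with no window on a given date becomes NaN, and duplicate pairs, if any, would be averaged because pivot_table's default aggfunc is 'mean'.

```python
import pandas as pd

# Toy records in the (timestamp, symbol, score) shape described above.
records = [
    (pd.Timestamp("2024-01-08"), "SH600000", 0.30),
    (pd.Timestamp("2024-01-05"), "SH600000", 0.10),
    (pd.Timestamp("2024-01-05"), "SZ000001", -0.20),
]
df = pd.DataFrame(records, columns=["datetime", "instrument", "score"])
# Long -> wide: one row per date, one column per instrument.
wide = df.pivot_table(index="datetime", columns="instrument", values="score").sort_index()
print(wide)
```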

Example Usage

import pickle
from qlib_test import generate_predictions

# Load test data
with open("./data/processed_datasets/test_data.pkl", 'rb') as f:
    test_data = pickle.load(f)

# Configure inference
run_config = {
    'device': 'cuda:0',
    'tokenizer_path': './outputs/models/finetune_tokenizer_demo/checkpoints/best_model',
    'model_path': './outputs/models/finetune_predictor_demo/checkpoints/best_model',
    'max_context': 512,
    'pred_len': 10,
    'clip': 5.0,
    'T': 0.6,
    'top_k': 0,
    'top_p': 0.9,
    'sample_count': 5,
    'batch_size': 1000,
}

# Generate prediction signals
prediction_dfs = generate_predictions(run_config, test_data)
# prediction_dfs keys: 'mean', 'last', 'max', 'min'
# Each value is a pd.DataFrame (datetime index x symbol columns)
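One way a backtest might read these signals is to take a single date's cross-section and rank instruments by score. The sketch below builds a toy stand-in for prediction_dfs (made-up numbers, not real output) so it runs on its own; the ranking step is illustrative and not part of qlib_test.py.

```python
import numpy as np
import pandas as pd

# Toy stand-in for generate_predictions' return value.
dates = pd.to_datetime(["2024-01-05", "2024-01-08"])
prediction_dfs = {
    sig: pd.DataFrame(
        np.array([[0.1, -0.2, 0.05], [0.3, np.nan, -0.1]]),
        index=pd.Index(dates, name="datetime"),
        columns=pd.Index(["SH600000", "SZ000001", "SZ000002"], name="instrument"),
    )
    for sig in ("mean", "last", "max", "min")
}

scores = prediction_dfs["mean"]
latest = scores.index.max()
# Cross-section for one date: drop symbols without a window that day, rank by score.
ranked = scores.loc[latest].dropna().sort_values(ascending=False)
print(ranked.head(2))
```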

Source Reference

File: finetune/qlib_test.py, lines 239-295 (generate_predictions), lines 32-89 (QlibTestDataset), lines 207-213 (load_models), lines 216-236 (collate_fn_for_inference).
