Implementation: Generate_Predictions_Qlib (shiyu-coder/Kronos)
| Field | Value |
|---|---|
| implementation_name | Generate_Predictions_Qlib |
| type | API Doc |
| repository | https://github.com/shiyu-coder/Kronos |
| source_file | finetune/qlib_test.py:L239-295 |
| implements | Principle:Shiyu_coder_Kronos_Qlib_Test_Inference |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The `generate_predictions` function runs batch inference over the entire test dataset with the fine-tuned Kronos models, producing trading-signal predictions as DataFrames ready for backtesting.
Function Signature
```python
def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]
```
Also in This File
- `QlibTestDataset` (lines 32-89): PyTorch Dataset for sequential test-data iteration
- `load_models(config: dict) -> tuple[KronosTokenizer, Kronos]` (lines 207-213): Loads fine-tuned models
- `collate_fn_for_inference(batch)` (lines 216-236): Custom collate for mixed tensor/string/timestamp batches
Import
```python
from qlib_test import generate_predictions
```
Dependencies
`torch`, `numpy`, `pandas`, `tqdm`, `collections.defaultdict`, `model.kronos.Kronos`, `model.kronos.KronosTokenizer`, `model.kronos.auto_regressive_inference`
Input
| Parameter | Type | Description |
|---|---|---|
| `config` | `dict` | Configuration dict with keys: `device`, `tokenizer_path`, `model_path`, `max_context`, `pred_len`, `clip`, `T`, `top_k`, `top_p`, `sample_count`, `batch_size` |
| `test_data` | `dict` | Test data loaded from pickle: `dict[symbol -> pd.DataFrame]` |
Output
A dictionary with keys 'mean', 'last', 'max', 'min', each mapping to a pd.DataFrame with:
- Index: datetime
- Columns: symbol names (instrument codes)
- Values: prediction scores (close price deltas)
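As a concrete illustration of that layout, here is a toy version of one of the four returned DataFrames (the symbols, dates, and scores are made up for this sketch, not taken from the repository):

```python
import pandas as pd

# Hypothetical output frame: datetime index x instrument columns,
# values are prediction scores (close-price deltas).
idx = pd.to_datetime(["2024-01-02", "2024-01-03"])
df = pd.DataFrame(
    {"SH600000": [0.12, -0.05], "SZ000001": [0.03, 0.08]},
    index=idx,
)
df.index.name = "datetime"
df.columns.name = "instrument"
print(df)
```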
QlibTestDataset
```python
class QlibTestDataset(Dataset):
    def __init__(self, data: dict, config: Config)
    def __len__(self) -> int
    def __getitem__(self, idx: int) -> tuple
```
Unlike the training `QlibDataset`, this dataset:
- Iterates sequentially through all windows (no random sampling)
- Returns 5 items per sample: `(x, x_stamp, y_stamp, symbol, timestamp)`. The `y_stamp` (future time features) is needed for autoregressive inference
- Uses a window size of `lookback_window + predict_window` (no +1, since there is no overlap with the target)
- Applies instance normalization identically to training
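The window arithmetic implied by sequential iteration can be sanity-checked with a few lines of plain Python (the window lengths below are illustrative values, not the repository's actual config):

```python
# Illustrative values; the real lookback/predict windows come from the config.
lookback_window = 90
predict_window = 10
series_len = 250  # rows available for one symbol

# Window covers the lookback plus the prediction horizon (no +1).
window_size = lookback_window + predict_window
# Sliding one step at a time yields this many windows per symbol.
n_windows = series_len - window_size + 1
print(window_size, n_windows)
```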
collate_fn_for_inference
```python
def collate_fn_for_inference(batch) -> tuple
```
Custom collate function that handles mixed data types in each batch:
- Tensors (`x`, `x_stamp`, `y_stamp`): stacked via `torch.stack()`
- Strings (`symbols`): collected into a list
- Timestamps (`timestamps`): collected into a list
This is necessary because PyTorch's default collate cannot handle non-tensor elements.
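The stacking logic can be sketched without PyTorch: array fields are stacked into a batch dimension while strings and timestamps pass through as plain lists. NumPy stands in for `torch.stack` here, and the function name and shapes are illustrative, not the repository's actual code:

```python
import numpy as np

def collate_for_inference(batch):
    """Sketch of the collate logic: stack array fields, list the rest.

    Each batch item is (x, x_stamp, y_stamp, symbol, timestamp).
    np.stack stands in for torch.stack in this illustration.
    """
    xs, x_stamps, y_stamps, symbols, timestamps = zip(*batch)
    return (
        np.stack(xs),        # (B, lookback, features)
        np.stack(x_stamps),  # (B, lookback, time_features)
        np.stack(y_stamps),  # (B, pred_len, time_features)
        list(symbols),       # non-tensor: keep as a list
        list(timestamps),    # non-tensor: keep as a list
    )

# Tiny two-sample batch with made-up shapes
batch = [
    (np.zeros((4, 6)), np.zeros((4, 2)), np.zeros((2, 2)), "SH600000", "2024-01-02"),
    (np.ones((4, 6)), np.ones((4, 2)), np.ones((2, 2)), "SZ000001", "2024-01-03"),
]
x, x_stamp, y_stamp, syms, ts = collate_for_inference(batch)
print(x.shape, syms)
```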
load_models
```python
def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]
```
Loads both models from the paths specified in config, moves them to the target device, and sets them to eval mode.
Inference Core
```python
with torch.no_grad():
    for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
        preds = auto_regressive_inference(
            tokenizer, model,
            x.to(device), x_stamp.to(device), y_stamp.to(device),
            max_context=config['max_context'],
            pred_len=config['pred_len'],
            clip=config['clip'],
            T=config['T'],
            top_k=config['top_k'],
            top_p=config['top_p'],
            sample_count=config['sample_count']
        )
        # Keep only the prediction horizon
        preds = preds[:, -config['pred_len']:, :]
```
Signal Extraction
Close price is at index 3 in the feature list `[open, high, low, close, vol, amt]`. Signals are computed as deltas from the last observed close:
```python
last_day_close = x[:, -1, 3].numpy()
signals = {
    'last': preds[:, -1, 3] - last_day_close,
    'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
    'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
    'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
}
```
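A worked example of the delta computation on a tiny synthetic batch (shapes and values are made up; only the indexing logic mirrors the source):

```python
import numpy as np

# Synthetic predictions: 2 samples, 3-step horizon, 6 features
# [open, high, low, close, vol, amt]; close is at index 3.
preds = np.zeros((2, 3, 6))
preds[:, :, 3] = [[10.0, 11.0, 12.0],
                  [20.0, 19.0, 21.0]]  # predicted closes per step
last_day_close = np.array([10.0, 20.0])  # last observed close per sample

signals = {
    'last': preds[:, -1, 3] - last_day_close,              # final-step delta
    'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
    'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
    'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
}
print(signals['last'])  # [2. 1.]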
Post-processing
Results are collected as (timestamp, symbol, score) tuples and pivoted into DataFrames:
```python
for sig_type, records in results.items():
    df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
    pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
    prediction_dfs[sig_type] = pivot_df.sort_index()
```
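With a handful of made-up records, the pivot step produces the datetime-by-instrument layout described under Output; a symbol with no prediction on a given date simply becomes NaN in that cell:

```python
import pandas as pd

# Made-up (timestamp, symbol, score) records for one signal type
records = [
    ("2024-01-03", "SH600000", 0.5),
    ("2024-01-02", "SH600000", 0.1),
    ("2024-01-02", "SZ000001", -0.2),
]
df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
pivot_df = pivot_df.sort_index()  # chronological index order
print(pivot_df)
```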
Example Usage
```python
import pickle
from qlib_test import generate_predictions

# Load test data
with open("./data/processed_datasets/test_data.pkl", 'rb') as f:
    test_data = pickle.load(f)

# Configure inference
run_config = {
    'device': 'cuda:0',
    'tokenizer_path': './outputs/models/finetune_tokenizer_demo/checkpoints/best_model',
    'model_path': './outputs/models/finetune_predictor_demo/checkpoints/best_model',
    'max_context': 512,
    'pred_len': 10,
    'clip': 5.0,
    'T': 0.6,
    'top_k': 0,
    'top_p': 0.9,
    'sample_count': 5,
    'batch_size': 1000,
}

# Generate prediction signals
prediction_dfs = generate_predictions(run_config, test_data)
# prediction_dfs keys: 'mean', 'last', 'max', 'min'
# Each value is a pd.DataFrame (datetime index x symbol columns)
```
Source Reference
File: finetune/qlib_test.py, lines 239-295 (generate_predictions), lines 32-89 (QlibTestDataset), lines 207-213 (load_models), lines 216-236 (collate_fn_for_inference).