Implementation:Shiyu coder Kronos KronosPredictor Predict Batch

From Leeroopedia


Field                Value
implementation_name  KronosPredictor_Predict_Batch
repo                 Shiyu_coder_Kronos
type                 API Doc
source_file          model/kronos.py:L562-661
class                KronosPredictor
implements           Principle:Shiyu_coder_Kronos_Batch_Forecasting
last_updated         2026-02-09 14:00 GMT

Summary

The KronosPredictor.predict_batch method generates candlestick forecasts for multiple financial time series in a single batched inference call on the predictor's device (typically a GPU). It returns a list with one DataFrame of predicted OHLCV + amount values per input series.

API Signature

KronosPredictor.predict_batch(
    df_list: List[pd.DataFrame],
    x_timestamp_list: List[pd.DatetimeIndex],
    y_timestamp_list: List[pd.DatetimeIndex],
    pred_len: int,
    T: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    sample_count: int = 1,
    verbose: bool = True
) -> List[pd.DataFrame]

Import

from model import KronosPredictor

Parameters

Parameter          Type                     Default     Description
df_list            List[pd.DataFrame]       (required)  List of historical OHLCV DataFrames, one per series. Each must contain open, high, low and close columns; volume and amount are optional.
x_timestamp_list   List[pd.DatetimeIndex]   (required)  List of historical timestamp indices, one per series. Each length must match the number of rows in the corresponding DataFrame.
y_timestamp_list   List[pd.DatetimeIndex]   (required)  List of future timestamp indices, one per series. Each length must equal pred_len.
pred_len           int                      (required)  Number of future time steps to predict.
T                  float                    1.0         Sampling temperature; lower values make sampling more deterministic.
top_k              int                      0           Top-k filtering threshold; 0 disables top-k filtering.
top_p              float                    0.9         Top-p (nucleus sampling) threshold.
sample_count       int                      1           Number of parallel samples drawn per series; the samples are averaged into one forecast.
verbose            bool                     True        Whether to display the autoregressive progress bar.
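The T, top_k and top_p parameters follow the usual temperature / top-k / nucleus-sampling conventions. As a generic illustration only (a sketch, not Kronos's own sampling code, which operates on its own token hierarchy), the hypothetical helper sample_token below applies the three filters to a 1-D array of logits:

```python
import numpy as np


def sample_token(logits, T=1.0, top_k=0, top_p=0.9, rng=None):
    """Draw one token id from 1-D logits using temperature, top-k and top-p filtering.

    Hypothetical helper for illustration; not part of the Kronos API.
    """
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / T  # T < 1 sharpens, T > 1 flattens

    if top_k > 0:
        # Keep only the k largest logits (ties at the k-th value survive); 0 disables this filter
        kth = np.sort(scaled)[-min(top_k, scaled.size)]
        scaled = np.where(scaled < kth, -np.inf, scaled)

    # Softmax over the surviving logits; masked entries get probability 0
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    if top_p < 1.0:
        # Keep the smallest set of most-likely tokens whose cumulative probability reaches top_p
        order = np.argsort(probs)[::-1]
        cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
        mask = np.zeros(probs.size, dtype=bool)
        mask[order[:cutoff]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()

    return int(rng.choice(probs.size, p=probs))
```

With the default top_k=0 only the temperature and the 0.9 nucleus cut apply, so the least-likely tail of the distribution is never sampled.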

Input

  • df_list (List[pd.DataFrame]): List of N historical candlestick DataFrames. All must have the same number of rows (same historical length).
  • x_timestamp_list (List[pd.DatetimeIndex]): List of N historical timestamp series.
  • y_timestamp_list (List[pd.DatetimeIndex]): List of N future timestamp series. All must have the same length, equal to pred_len.

Output

  • List[pd.DataFrame]: List of N DataFrames in the same order as input. Each DataFrame has columns [open, high, low, close, volume, amount] indexed by the corresponding entry from y_timestamp_list.
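Because the result is a plain Python list aligned with df_list, it is often convenient to merge it into one long DataFrame. A minimal pandas sketch, assuming you have one label per series; the helper name stack_predictions and the level names series/timestamp are illustrative choices, not part of the Kronos API:

```python
import pandas as pd


def stack_predictions(pred_dfs, names):
    """Merge per-series forecast DataFrames into one (series, timestamp)-indexed frame."""
    if len(pred_dfs) != len(names):
        raise ValueError("need exactly one name per forecast DataFrame")
    # keys= adds an outer index level holding each series' name
    return pd.concat(pred_dfs, keys=names, names=["series", "timestamp"])


# Usage with two tiny stand-in forecasts (constant values, not real model output)
idx = pd.date_range("2024-01-01 17:50", periods=2, freq="1min")
cols = ["open", "high", "low", "close", "volume", "amount"]
combined = stack_predictions(
    [pd.DataFrame(1.0, index=idx, columns=cols), pd.DataFrame(2.0, index=idx, columns=cols)],
    ["stock_a", "stock_b"],
)
```

A single series can then be recovered with combined.loc["stock_b"].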

Key Constraint

All series must have identical dimensions:

  • All historical DataFrames must have the same number of rows (same historical length).
  • All future timestamp lists must have the same length (same prediction length, equal to pred_len).

If these constraints are violated, a ValueError is raised.
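When the series you want to forecast have different history lengths, one workaround (our assumption, not something predict_batch does for you) is to trim every series to the most recent rows they all share. The hypothetical helper align_histories below does this for the DataFrames and their timestamps, accepting either a pandas Series or a DatetimeIndex for each timestamp entry:

```python
import pandas as pd


def _tail(obj, n):
    """Last n entries of a DataFrame/Series (via .iloc) or of an Index such as DatetimeIndex."""
    return obj.iloc[-n:] if isinstance(obj, (pd.DataFrame, pd.Series)) else obj[-n:]


def align_histories(df_list, x_timestamp_list):
    """Trim every series to the shortest history, keeping the most recent rows."""
    if len(df_list) != len(x_timestamp_list):
        raise ValueError("df_list and x_timestamp_list must have the same length")
    for i, (df, ts) in enumerate(zip(df_list, x_timestamp_list)):
        if len(df) != len(ts):
            raise ValueError(f"series {i}: {len(df)} rows but {len(ts)} timestamps")
    common = min(len(df) for df in df_list)
    if common == 0:
        raise ValueError("every series needs at least one row")
    return ([_tail(df, common) for df in df_list],
            [_tail(ts, common) for ts in x_timestamp_list])
```

Dropping the oldest rows keeps each series' most recent context, which the forecast conditions on most directly; the trade-off is that longer series lose part of their history.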

Internal Pipeline

1. Validate all inputs (types, columns, lengths)
2. For each series i:
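   a. Fill missing volume/amount columns (volume and amount default to 0 when volume is absent; amount is derived from volume × mean price when only amount is absent)
   b. Reject NaN values and check that the timestamp lengths match the data and pred_len
   c. Compute temporal features via calc_time_stamps()
   d. Instance-normalize: x = (x - mean) / (std + 1e-5), clip to [-clip, clip]
   e. Store normalized data, temporal features, mean, std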
3. Verify all series have consistent historical and prediction lengths
4. Stack into batch tensors:
   x_batch:       (N, seq_len, features)
   x_stamp_batch: (N, seq_len, time_features)
   y_stamp_batch: (N, pred_len, time_features)
5. Call self.generate() -> auto_regressive_inference()
   Returns preds: (N, pred_len, features)
6. For each series i:
   a. Denormalize: preds[i] * (std[i] + 1e-5) + mean[i]
   b. Create DataFrame indexed by y_timestamp_list[i]
7. Return list of DataFrames
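The normalize, stack, generate and denormalize flow above can be checked in isolation. The NumPy sketch below reproduces the per-series instance normalization, the batch stacking and the per-series denormalization, with a stand-in for self.generate() that simply echoes the last three normalized steps, so the round trip should recover the original values. The helper names are hypothetical and clip=5.0 is an assumed value for the predictor's clip setting:

```python
import numpy as np

EPS = 1e-5   # epsilon used in the documented normalization
CLIP = 5.0   # assumed clip value; use your predictor's actual setting


def normalize_series(x, clip=CLIP):
    """Instance-normalize one (seq_len, features) array; return data plus its mean/std."""
    mean, std = x.mean(axis=0), x.std(axis=0)
    x_norm = np.clip((x - mean) / (std + EPS), -clip, clip)
    return x_norm, mean, std


def denormalize(pred, mean, std):
    """Map predictions from normalized space back to one series' original scale."""
    return pred * (std + EPS) + mean


rng = np.random.default_rng(0)
# Two hypothetical series (8 steps, 6 features) on very different price scales
series = [rng.normal(100.0, 2.0, size=(8, 6)), rng.normal(5.0, 0.1, size=(8, 6))]
normed = [normalize_series(x) for x in series]
x_batch = np.stack([x_norm for x_norm, _, _ in normed], axis=0)  # (N, seq_len, features)

# Stand-in for self.generate(): echo the last 3 normalized steps as "predictions"
preds = x_batch[:, -3:, :]  # (N, pred_len, features)
restored = [denormalize(preds[i], mean, std) for i, (_, mean, std) in enumerate(normed)]
```

Because each series keeps its own mean and std, the series near 100 and the series near 5 come back on their own scales even though they shared one batch tensor.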

Example

import pandas as pd
from model import KronosTokenizer, Kronos, KronosPredictor

# Setup
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, device="cuda:0")

# Prepare multiple series (all must share the same history length).
# stock_a_df, stock_b_df, stock_c_df are placeholders for your own OHLCV DataFrames, 500 rows each.
df_list = [stock_a_df, stock_b_df, stock_c_df]

# Timestamps are passed as datetime Series; the time features are read via the .dt accessor.
x_ts = pd.Series(pd.date_range("2024-01-01 09:30", periods=500, freq="1min"))  # history ends at 17:49
y_ts = pd.Series(pd.date_range("2024-01-01 17:50", periods=120, freq="1min"))  # next 120 minutes
x_timestamps = [x_ts, x_ts, x_ts]
y_timestamps = [y_ts, y_ts, y_ts]

# Batch predict
pred_dfs = predictor.predict_batch(
    df_list=df_list,
    x_timestamp_list=x_timestamps,
    y_timestamp_list=y_timestamps,
    pred_len=120,
    T=1.0,
    top_p=0.9,
    sample_count=3
)

# pred_dfs is a list of 3 DataFrames
for i, pdf in enumerate(pred_dfs):
    print(f"Stock {i}: {pdf.shape}")
    print(pdf.head())

Source Code Reference

File: model/kronos.py, lines 562-661.

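# Abridged excerpt of KronosPredictor.predict_batch; "..." marks elided code.
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len,
                  T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
    # Validate container types and matching list lengths
    if not all(isinstance(obj, (list, tuple))
               for obj in (df_list, x_timestamp_list, y_timestamp_list)):
        raise ValueError("...")
    if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
        raise ValueError("...")

    num_series = len(df_list)
    x_list, x_stamp_list, y_stamp_list = [], [], []
    means, stds, seq_lens, y_lens = [], [], [], []

    # Per-series normalization (x: (seq_len, features) array for series i)
    for i in range(num_series):
        # ... column checks, volume/amount filling, NaN check, calc_time_stamps() ...
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x_norm = (x - x_mean) / (x_std + 1e-5)
        x_norm = np.clip(x_norm, -self.clip, self.clip)
        # ... store normalized data, time features, mean and std ...
        seq_lens.append(x_norm.shape[0])
        y_lens.append(y_stamp.shape[0])

    # Verify consistent dimensions across the batch
    if len(set(seq_lens)) != 1:
        raise ValueError("Parallel prediction requires consistent historical lengths")
    if len(set(y_lens)) != 1:
        raise ValueError("...")

    # Stack into (N, seq_len, features), (N, seq_len, time_features), (N, pred_len, time_features)
    x_batch = np.stack(x_list, axis=0)
    x_stamp_batch = np.stack(x_stamp_list, axis=0)
    y_stamp_batch = np.stack(y_stamp_list, axis=0)
    preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, ...)  # (N, pred_len, features)

    # Per-series denormalization
    pred_dfs = []
    for i in range(num_series):
        preds_i = preds[i] * (stds[i] + 1e-5) + means[i]
        pred_dfs.append(pd.DataFrame(preds_i, ...))

    return pred_dfs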

Error Handling

  • Raises ValueError if df_list, x_timestamp_list, or y_timestamp_list is not a list or tuple.
  • Raises ValueError if the three lists have inconsistent lengths.
  • Raises ValueError if any element of df_list is not a pandas DataFrame or is missing a required price column (open, high, low, close).
  • Raises ValueError if any DataFrame contains NaN values in price or volume columns.
  • Raises ValueError if historical lengths are inconsistent across series.
  • Raises ValueError if prediction lengths are inconsistent across series.
  • Raises ValueError if any y_timestamp length does not match pred_len.
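These conditions can also be checked on the caller side before the (potentially expensive) model call. The hypothetical preflight_check below mirrors the listed ValueError cases with plain pandas; it is a sketch, not the library's validation code, and its NaN check only inspects whichever of open/high/low/close/volume/amount are present, since predict_batch fills missing volume/amount itself:

```python
import pandas as pd

PRICE_COLS = ["open", "high", "low", "close"]


def preflight_check(df_list, x_timestamp_list, y_timestamp_list, pred_len):
    """Raise ValueError for the documented predict_batch input problems."""
    for name, obj in (("df_list", df_list), ("x_timestamp_list", x_timestamp_list),
                      ("y_timestamp_list", y_timestamp_list)):
        if not isinstance(obj, (list, tuple)):
            raise ValueError(f"{name} must be a list or tuple")
    if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
        raise ValueError("the three input lists must have the same length")
    for i, df in enumerate(df_list):
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"df_list[{i}] is not a pandas DataFrame")
        missing = [c for c in PRICE_COLS if c not in df.columns]
        if missing:
            raise ValueError(f"df_list[{i}] is missing columns {missing}")
        present = [c for c in PRICE_COLS + ["volume", "amount"] if c in df.columns]
        if df[present].isna().to_numpy().any():
            raise ValueError(f"df_list[{i}] contains NaN values")
        if len(x_timestamp_list[i]) != len(df):
            raise ValueError(f"x_timestamp_list[{i}] length does not match df_list[{i}]")
    if len({len(df) for df in df_list}) > 1:
        raise ValueError("historical lengths differ across series")
    for i, ts in enumerate(y_timestamp_list):
        if len(ts) != pred_len:
            raise ValueError(f"y_timestamp_list[{i}] has length {len(ts)}, expected {pred_len}")
```

Requiring every y_timestamp length to equal pred_len also guarantees that prediction lengths are consistent across series.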

Notes

  • Normalization and denormalization are performed per-series, not across the batch. Each series retains its own scale.
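  • GPU memory scales with N_series × sample_count: each series is replicated sample_count times for sampling, so the effective batch size during autoregressive decoding is N_series × sample_count. For large batches with high sample counts, monitor GPU memory usage.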
  • The method delegates to self.generate() which calls auto_regressive_inference() with the stacked batch tensor.
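The sample_count scaling can be sketched with NumPy: each series is replicated sample_count times, the model decodes the enlarged batch, and the samples of each series are averaged back into one forecast. This illustrates the shape bookkeeping only; the helper names are hypothetical and this is not the library's implementation:

```python
import numpy as np


def replicate_for_sampling(x_batch, sample_count):
    """(N, L, F) -> (N * sample_count, L, F); copies of one series are adjacent."""
    return np.repeat(x_batch, sample_count, axis=0)


def average_samples(samples, num_series, sample_count):
    """(N * sample_count, P, F) -> (N, P, F) by averaging each series' samples."""
    return samples.reshape(num_series, sample_count, *samples.shape[1:]).mean(axis=1)


x_batch = np.zeros((3, 500, 6), dtype=np.float32)  # 3 series, 500 steps, 6 features
enlarged = replicate_for_sampling(x_batch, sample_count=4)
print(enlarged.shape)  # (12, 500, 6): an effective batch of 3 * 4 sequences
```

Because np.repeat places the copies of each series next to each other, the reshape to (num_series, sample_count, ...) groups exactly the samples that belong to the same series before averaging.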
