Implementation:Shiyu coder Kronos KronosPredictor Predict

From Leeroopedia


| Field | Value |
|---|---|
| implementation_name | KronosPredictor_Predict |
| repo | Shiyu_coder_Kronos |
| type | API Doc |
| source_file | model/kronos.py:L519-559 |
| class | KronosPredictor |
| implements | Principle:Shiyu_coder_Kronos_Single_Series_Forecasting |
| last_updated | 2026-02-09 14:00 GMT |

Summary

The KronosPredictor.predict method generates probabilistic candlestick forecasts for a single financial time series, returning a DataFrame of predicted OHLCV+amount values indexed by future timestamps.

API Signature

KronosPredictor.predict(
    df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    T: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    sample_count: int = 1,
    verbose: bool = True
) -> pd.DataFrame

Import

from model import KronosPredictor

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| df | pd.DataFrame | (required) | Historical OHLCV data. Must contain the columns open, high, low, close; volume and amount are optional. |
| x_timestamp | pd.Series / pd.DatetimeIndex | (required) | Timestamps corresponding to the historical data rows. Used to compute temporal features (minute, hour, weekday, day, month). |
| y_timestamp | pd.Series / pd.DatetimeIndex | (required) | Future timestamps for the prediction horizon. Used as temporal features during generation and as the output DataFrame index. |
| pred_len | int | (required) | Number of future timesteps to predict. |
| T | float | 1.0 | Sampling temperature. Higher values increase randomness; lower values make predictions more deterministic. |
| top_k | int | 0 | Top-k filtering threshold: only the k most probable tokens are kept. 0 disables top-k filtering. |
| top_p | float | 0.9 | Top-p (nucleus sampling) threshold. Keeps the smallest set of most probable tokens whose cumulative probability reaches top_p. |
| sample_count | int | 1 | Number of parallel samples to generate and average for more stable predictions. |
| verbose | bool | True | Whether to display a progress bar during autoregressive generation. |
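
To make T, top_k and top_p concrete, the following NumPy sketch applies temperature scaling, then top-k filtering, then top-p (nucleus) filtering to a single logit vector. It illustrates the general technique only: `filter_logits` is a hypothetical helper written for this page, not the function Kronos calls internally, and the order of the filters is an assumption.

```python
import numpy as np

def filter_logits(logits, T=1.0, top_k=0, top_p=0.9):
    """Return sampling probabilities after temperature, top-k and top-p filtering.

    Hypothetical helper for illustration; not part of the Kronos API.
    """
    # Temperature scaling: T > 1 flattens the distribution, T < 1 sharpens it.
    logits = np.asarray(logits, dtype=np.float64) / T

    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth = np.sort(logits)[-min(top_k, logits.size)]
        logits = np.where(logits < kth, -np.inf, logits)

    # Numerically stable softmax.
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()

    if 0.0 < top_p < 1.0:
        order = np.argsort(-probs)            # most probable first
        cum = np.cumsum(probs[order])
        # Smallest prefix whose cumulative probability reaches top_p.
        cutoff = np.searchsorted(cum, top_p) + 1
        keep = np.zeros(probs.shape, dtype=bool)
        keep[order[:cutoff]] = True
        probs = np.where(keep, probs, 0.0)
        probs /= probs.sum()

    return probs
```

A token can then be drawn with `np.random.default_rng().choice(len(probs), p=probs)`. With the defaults (top_k=0, top_p=0.9), only the nucleus filter is active.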

Input

  • df (pd.DataFrame): Historical candlestick data with at least the columns open, high, low, close. If volume is missing, it is filled with zeros. If amount is missing, it is computed as volume multiplied by the mean of the price columns.
  • x_timestamp (pd.Series / pd.DatetimeIndex): Timestamps for the historical rows, one per row of df, used to extract temporal features.
  • y_timestamp (pd.Series / pd.DatetimeIndex): Timestamps for the prediction horizon. Because the output is indexed by y_timestamp, it should contain pred_len entries.
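
The input preparation described above can be sketched with plain pandas. Both functions are hypothetical stand-ins written for this page (the library's own helper is calc_time_stamps()), and the row-wise interpretation of mean(price_cols) is an assumption.

```python
import pandas as pd

PRICE_COLS = ["open", "high", "low", "close"]  # the required columns

def fill_volume_amount(df):
    """Illustrative stand-in for the fill step; not the library's code."""
    df = df.copy()  # leave the caller's DataFrame untouched
    if "volume" not in df.columns:
        df["volume"] = 0.0
    if "amount" not in df.columns:
        # Assumption: mean(price_cols) is the row-wise mean of open/high/low/close.
        df["amount"] = df["volume"] * df[PRICE_COLS].mean(axis=1)
    return df

def time_features(timestamps):
    """Extract minute, hour, weekday, day and month from datetime-like input."""
    # Accepts a Series or a DatetimeIndex; the .dt accessor needs a Series.
    ts = pd.Series(pd.to_datetime(timestamps)).reset_index(drop=True)
    return pd.DataFrame({
        "minute": ts.dt.minute,
        "hour": ts.dt.hour,
        "weekday": ts.dt.weekday,  # Monday == 0
        "day": ts.dt.day,
        "month": ts.dt.month,
    })
```

When volume is absent, this sketch yields zeros for both volume and amount, since amount is derived from the zero-filled volume.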

Output

  • pd.DataFrame: A DataFrame with columns [open, high, low, close, volume, amount] indexed by y_timestamp. Contains the predicted candlestick values in the original price scale.

Internal Pipeline

1. Validate input DataFrame columns and NaN values
2. Fill missing volume/amount columns
3. Compute temporal features via calc_time_stamps()
4. Instance-normalize: x = (x - mean) / (std + 1e-5), then clip to [-clip, clip]
5. Add batch dimension: x[np.newaxis, :]
6. Call self.generate() -> auto_regressive_inference()
7. Remove batch dimension: preds.squeeze(0)
8. Denormalize: preds = preds * (std + 1e-5) + mean
9. Return as pd.DataFrame indexed by y_timestamp
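
Steps 4, 5, 7 and 8 form a round trip around the model call. The sketch below reproduces that round trip in NumPy with the model call left out. `normalize` and `denormalize` are hypothetical helper names, and clip=5.0 is an assumed value; the source only refers to self.clip.

```python
import numpy as np

def normalize(x, clip=5.0, eps=1e-5):
    """Per-column z-score followed by clipping (step 4)."""
    mean, std = np.mean(x, axis=0), np.std(x, axis=0)
    z = np.clip((x - mean) / (std + eps), -clip, clip)
    return z, mean, std

def denormalize(z, mean, std, eps=1e-5):
    """Inverse of the z-score using the same statistics (step 8)."""
    return z * (std + eps) + mean

# Deterministic toy input: 10 timesteps, 6 features (OHLC, volume, amount).
x = np.arange(60, dtype=np.float64).reshape(10, 6)

z, mean, std = normalize(x)
batched = z[np.newaxis, :]                    # step 5: shape (1, 10, 6)
restored = denormalize(batched.squeeze(0), mean, std)  # steps 7 and 8
```

The round trip is exact only when no value was clipped. In predict, the statistics come from the historical window and are then applied to the generated values, so predictions are mapped back to the price scale of the input.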

Example

import pandas as pd
from model import KronosTokenizer, Kronos, KronosPredictor

# Setup
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, device="cuda:0")

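# Prepare data (replace each [...] placeholder with real, equal-length lists)
x_df = pd.DataFrame({
    'open': [...], 'high': [...], 'low': [...], 'close': [...],
    'volume': [...], 'amount': [...]
})
# Wrapped in pd.Series to match the declared signature (x_timestamp: pd.Series).
x_timestamp = pd.Series(pd.date_range("2024-01-01 09:30", periods=len(x_df), freq="1min"))
# Start the horizon one step after the last historical bar; its length must equal pred_len.
y_start = x_timestamp.iloc[-1] + pd.Timedelta(minutes=1)
y_timestamp = pd.Series(pd.date_range(y_start, periods=120, freq="1min"))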

# Predict
pred_df = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=120,
    T=1.0,
    top_p=0.9,
    sample_count=5
)

# pred_df has columns: open, high, low, close, volume, amount
# pred_df.index is y_timestamp
print(pred_df.head())

Source Code Reference

File: model/kronos.py, lines 519-559 (abridged excerpt).

def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame.")
    if not all(col in df.columns for col in self.price_cols):
        raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")

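    df = df.copy()
    # ... fill missing volume/amount, check for NaN values ...
    # ... build x (value array) and x_stamp / y_stamp (temporal features); elided in this excerpt ...

    # Instance normalization with per-column statistics of the input window
    x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
    x = (x - x_mean) / (x_std + 1e-5)
    x = np.clip(x, -self.clip, self.clip)

    # ... add batch dimension (elided) ...
    preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
    preds = preds.squeeze(0)  # remove batch dimension
    preds = preds * (x_std + 1e-5) + x_mean  # back to the original price scale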

    pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp)
    return pred_df

Error Handling

  • Raises ValueError if input is not a pandas DataFrame.
  • Raises ValueError if required price columns (open, high, low, close) are missing.
  • Raises ValueError if NaN values are found in price or volume columns.
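
Callers who want to fail fast before loading a model can mirror these checks up front. The sketch below is a hypothetical pre-flight helper, `validate_ohlc`, that raises ValueError under the same three conditions; its messages and exact column handling are assumptions and may differ from the library's.

```python
import pandas as pd

REQUIRED = ["open", "high", "low", "close"]

def validate_ohlc(df):
    """Hypothetical pre-flight check mirroring the documented ValueErrors."""
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame.")

    missing = [c for c in REQUIRED if c not in df.columns]
    if missing:
        raise ValueError(f"Missing price columns: {missing}")

    # Volume is optional, so only check it when present.
    to_check = REQUIRED + (["volume"] if "volume" in df.columns else [])
    if df[to_check].isnull().values.any():
        raise ValueError("NaN values found in price or volume columns.")

    return df
```

Wrapping the call in `try: ... except ValueError as err:` lets a pipeline log the problem and skip the series instead of aborting.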

Notes

  • The method creates a copy of the input DataFrame to avoid modifying the original.
  • Normalization uses per-column mean and standard deviation (instance normalization).
  • The sample_count parameter controls how many sampled forecasts are generated and averaged. Higher values give more stable results at a correspondingly higher compute cost.
  • The method delegates to self.generate(), which internally calls auto_regressive_inference().
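
The stabilizing effect of sample_count can be illustrated without the model. In the sketch below, synthetic Gaussian noise stands in for sampled model outputs (an assumption made purely for illustration), and averaging over the sample axis shrinks the spread of the result.

```python
import numpy as np

rng = np.random.default_rng(0)
pred_len, n_features = 120, 6

def draw_samples(sample_count):
    # Stand-in for sampled forecasts: (sample_count, pred_len, n_features).
    # Real samples would come from the model, not from a normal distribution.
    return rng.normal(size=(sample_count, pred_len, n_features))

# Averaging over axis 0 collapses the samples into one forecast.
single = draw_samples(1).mean(axis=0)
averaged = draw_samples(25).mean(axis=0)

print(single.std(), averaged.std())
```

For independent samples the standard deviation of the average falls roughly with the square root of the sample count, which is why returns diminish as sample_count grows.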
