Implementation:Shiyu coder Kronos KronosPredictor Predict
| Field | Value |
|---|---|
| implementation_name | KronosPredictor_Predict |
| repo | Shiyu_coder_Kronos |
| type | API Doc |
| source_file | model/kronos.py:L519-559 |
| class | KronosPredictor |
| implements | Principle:Shiyu_coder_Kronos_Single_Series_Forecasting |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The KronosPredictor.predict method generates probabilistic candlestick forecasts for a single financial time series, returning a DataFrame of predicted OHLCV+amount values indexed by future timestamps.
API Signature
KronosPredictor.predict(
    df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    T: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    sample_count: int = 1,
    verbose: bool = True
) -> pd.DataFrame
Import
from model import KronosPredictor
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| df | pd.DataFrame | (required) | Historical OHLCV data. Must contain columns: open, high, low, close. Optionally volume and amount. |
| x_timestamp | pd.Series / pd.DatetimeIndex | (required) | Timestamps corresponding to the historical data rows. Used to compute temporal features (minute, hour, weekday, day, month). |
| y_timestamp | pd.Series / pd.DatetimeIndex | (required) | Future timestamps for the prediction horizon. Used as temporal features during generation and as the output DataFrame index. |
| pred_len | int | (required) | Number of future timesteps to predict. |
| T | float | 1.0 | Sampling temperature. Higher values increase randomness; lower values make predictions more deterministic. |
| top_k | int | 0 | Top-k filtering threshold. 0 disables top-k filtering. |
| top_p | float | 0.9 | Top-p (nucleus sampling) threshold. Keeps the smallest set of highest-probability tokens whose cumulative probability reaches top_p. |
| sample_count | int | 1 | Number of parallel samples to generate and average for more stable predictions. |
| verbose | bool | True | Whether to display a progress bar during autoregressive generation. |
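The interaction of T, top_k, and top_p can be sketched in plain Python. This is a generic illustration of temperature scaling with top-k and nucleus filtering, not the code Kronos runs internally; the function name filter_probs and the toy logits are assumptions made for the example.

```python
import math

def filter_probs(logits, T=1.0, top_k=0, top_p=0.9):
    """Return {token_index: probability} after temperature, top-k and top-p filtering."""
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / T for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Visit tokens from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # top_k == 0 leaves the candidate set untouched

    # Nucleus step: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize the surviving tokens so they form a distribution again.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

# Toy logits whose softmax is (0.5, 0.3, 0.15, 0.05).
toy_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
nucleus = filter_probs(toy_logits, top_p=0.9)            # drops the 0.05 tail token
top_two = filter_probs(toy_logits, top_k=2, top_p=1.0)   # keeps only the two best tokens
```

Lowering T sharpens the distribution before filtering, so fewer tokens are needed to reach top_p and sampling becomes more deterministic.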
Input
- df (pd.DataFrame): Historical candlestick data with at minimum the columns open, high, low, close. If volume is missing, it is filled with zeros. If amount is missing, it is computed as volume * mean(price_cols).
- x_timestamp (pd.DatetimeIndex): Datetime index for historical data, used to extract temporal features.
- y_timestamp (pd.DatetimeIndex): Datetime index for the prediction horizon.
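The fill rule for missing volume and amount columns can be illustrated as follows. This is a minimal sketch, assuming mean(price_cols) is taken row-wise across the four price columns; the helper name fill_volume_amount and the constant PRICE_COLS are hypothetical and only stand in for the predictor's own attributes.

```python
import pandas as pd

PRICE_COLS = ["open", "high", "low", "close"]  # assumed to mirror self.price_cols

def fill_volume_amount(df):
    """Return a copy of df that is guaranteed to have volume and amount columns."""
    out = df.copy()  # leave the caller's DataFrame untouched
    if "volume" not in out.columns:
        out["volume"] = 0.0
    if "amount" not in out.columns:
        # Assumption: row-wise mean of the price columns, multiplied by volume.
        out["amount"] = out["volume"] * out[PRICE_COLS].mean(axis=1)
    return out

bars = pd.DataFrame({
    "open": [1.0, 2.0], "high": [3.0, 4.0],
    "low": [1.0, 2.0], "close": [3.0, 4.0],
    "volume": [10.0, 20.0],
})
with_amount = fill_volume_amount(bars)
prices_only = fill_volume_amount(bars.drop(columns=["volume"]))
```

When volume is absent as well, the derived amount is zero for every row, since it is a product with the zero-filled volume.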
Output
- pd.DataFrame: A DataFrame with columns [open, high, low, close, volume, amount] indexed by y_timestamp. Contains the predicted candlestick values in the original price scale.
Internal Pipeline
1. Validate input DataFrame columns and NaN values
2. Fill missing volume/amount columns
3. Compute temporal features via calc_time_stamps()
4. Instance-normalize: x = (x - mean) / (std + 1e-5), then clip to [-clip, clip]
5. Add batch dimension: x[np.newaxis, :]
6. Call self.generate() -> auto_regressive_inference()
7. Remove batch dimension: preds.squeeze(0)
8. Denormalize: preds = preds * (std + 1e-5) + mean
9. Return as pd.DataFrame indexed by y_timestamp
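Steps 4, 5, 7, and 8 form a round trip that can be checked in isolation with NumPy. The sketch below skips the model call entirely and only shows that denormalization inverts normalization; the helper names and the clip value of 5.0 are assumptions, not values taken from the source.

```python
import numpy as np

def normalize(x, clip=5.0, eps=1e-5):
    """Per-column instance normalization followed by clipping (clip=5.0 is assumed)."""
    mean, std = np.mean(x, axis=0), np.std(x, axis=0)
    z = (x - mean) / (std + eps)
    return np.clip(z, -clip, clip), mean, std

def denormalize(z, mean, std, eps=1e-5):
    """Invert normalize() for values that were not clipped."""
    return z * (std + eps) + mean

# Four timesteps, two feature columns on very different scales.
x = np.array([[10.0, 100.0], [11.0, 110.0], [12.0, 120.0], [13.0, 130.0]])
z, mean, std = normalize(x)

batched = z[np.newaxis, :]  # step 5: shape (1, 4, 2)
restored = denormalize(batched.squeeze(0), mean, std)  # steps 7 and 8
```

The 1e-5 term keeps the division finite when a column is constant over the lookback window, for example a volume column that was filled with zeros.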
Example
import pandas as pd
from model import KronosTokenizer, Kronos, KronosPredictor
# Setup
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, device="cuda:0")
# Prepare data
x_df = pd.DataFrame({
    'open': [...], 'high': [...], 'low': [...], 'close': [...],
    'volume': [...], 'amount': [...]
})
# Timestamps are wrapped in pd.Series to match the types in the signature
x_timestamp = pd.Series(pd.date_range("2024-01-01 09:30", periods=len(x_df), freq="1min"))
y_timestamp = pd.Series(pd.date_range("2024-01-01 11:30", periods=120, freq="1min"))
# Predict
pred_df = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=120,
    T=1.0,
    top_p=0.9,
    sample_count=5
)
# pred_df has columns: open, high, low, close, volume, amount
# pred_df.index is y_timestamp
print(pred_df.head())
Source Code Reference
File: model/kronos.py, lines 519-559.
def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame.")
    if not all(col in df.columns for col in self.price_cols):
        raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
    df = df.copy()
    # ... fill missing volume/amount ...
    x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
    x = (x - x_mean) / (x_std + 1e-5)
    x = np.clip(x, -self.clip, self.clip)
    preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
    preds = preds.squeeze(0)
    preds = preds * (x_std + 1e-5) + x_mean
    pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp)
    return pred_df
Error Handling
- Raises ValueError if input is not a pandas DataFrame.
- Raises ValueError if required price columns (open, high, low, close) are missing.
- Raises ValueError if NaN values are found in price or volume columns.
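The three checks above can be reproduced outside the predictor, which is useful for validating data before loading a model. This is a sketch under assumptions: validate_input is a hypothetical standalone helper, and the wording of the NaN error message is illustrative rather than copied from the source.

```python
import pandas as pd

PRICE_COLS = ["open", "high", "low", "close"]  # assumed to mirror self.price_cols

def validate_input(df):
    """Raise ValueError for the three failure cases listed above."""
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame.")
    if not all(col in df.columns for col in PRICE_COLS):
        raise ValueError(f"Price columns {PRICE_COLS} not found in DataFrame.")
    # volume is optional, so it is only checked when present.
    check_cols = [c for c in PRICE_COLS + ["volume"] if c in df.columns]
    if df[check_cols].isnull().values.any():
        # Illustrative message; the real wording may differ.
        raise ValueError("NaN values found in price or volume columns.")

good = pd.DataFrame({"open": [1.0], "high": [2.0], "low": [0.5], "close": [1.5]})
candidates = [
    good,                               # valid
    [1, 2, 3],                          # not a DataFrame
    good.drop(columns=["close"]),       # missing a price column
    good.assign(close=[float("nan")]),  # NaN in a price column
]

results = []
for candidate in candidates:
    try:
        validate_input(candidate)
        results.append("ok")
    except ValueError:
        results.append("rejected")
```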
Notes
- The method creates a copy of the input DataFrame to avoid modifying the original.
- Normalization uses per-column mean and standard deviation (instance normalization).
- The sample_count parameter controls how many generation runs are averaged. Higher values give more stable results but take proportionally longer.
- The method delegates to self.generate(), which internally calls auto_regressive_inference().
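The stabilizing effect of averaging several sampled paths can be seen with a toy generator. This sketch does not call Kronos: toy_path is a hypothetical stand-in for one stochastic generation run, and average_samples only illustrates the averaging idea, not how generate() batches its samples.

```python
import numpy as np

def toy_path(rng, pred_len=4, n_features=6):
    """Stand-in for one sampled forecast: true value 1.0 plus sampling noise."""
    return 1.0 + 0.1 * rng.standard_normal((pred_len, n_features))

def average_samples(draw_path, sample_count, seed=0):
    """Draw sample_count paths and average them element-wise."""
    rng = np.random.default_rng(seed)
    paths = np.stack([draw_path(rng) for _ in range(sample_count)])
    return paths.mean(axis=0)  # shape (pred_len, n_features)

single = average_samples(toy_path, sample_count=1)
averaged = average_samples(toy_path, sample_count=200)

# Largest deviation from the noise-free value in each case.
single_error = np.abs(single - 1.0).max()
averaged_error = np.abs(averaged - 1.0).max()
```

For independent samples the noise in the mean shrinks roughly with the square root of sample_count, while the cost grows linearly, which matches the trade-off described in the note above.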