Implementation:Shiyu coder Kronos KronosPredictor Predict Batch
| Field | Value |
|---|---|
| implementation_name | KronosPredictor_Predict_Batch |
| repo | Shiyu_coder_Kronos |
| type | API Doc |
| source_file | model/kronos.py:L562-661 |
| class | KronosPredictor |
| implements | Principle:Shiyu_coder_Kronos_Batch_Forecasting |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The KronosPredictor.predict_batch method generates candlestick forecasts for multiple financial time series in parallel on the GPU, returning a list of DataFrames with predicted OHLCV+amount values.
API Signature
```python
KronosPredictor.predict_batch(
    df_list: List[pd.DataFrame],
    x_timestamp_list: List[pd.DatetimeIndex],
    y_timestamp_list: List[pd.DatetimeIndex],
    pred_len: int,
    T: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    sample_count: int = 1,
    verbose: bool = True
) -> List[pd.DataFrame]
```
Import
```python
from model import KronosPredictor
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| df_list | List[pd.DataFrame] | (required) | List of historical OHLCV DataFrames, one per series. Each must contain open, high, low, close columns. |
| x_timestamp_list | List[pd.DatetimeIndex] | (required) | List of historical timestamp indices, one per series. Length must match the number of rows in the corresponding DataFrame. |
| y_timestamp_list | List[pd.DatetimeIndex] | (required) | List of future timestamp indices, one per series. Length must equal pred_len. |
| pred_len | int | (required) | Number of future timesteps to predict. |
| T | float | 1.0 | Sampling temperature. |
| top_k | int | 0 | Top-k filtering threshold. 0 disables top-k. |
| top_p | float | 0.9 | Top-p (nucleus sampling) threshold. |
| sample_count | int | 1 | Number of parallel samples per series, averaged internally. |
| verbose | bool | True | Whether to display the autoregressive progress bar. |
Input
- df_list (List[pd.DataFrame]): List of N historical candlestick DataFrames. All must have the same number of rows (same historical length).
- x_timestamp_list (List[pd.DatetimeIndex]): List of N historical timestamp series.
- y_timestamp_list (List[pd.DatetimeIndex]): List of N future timestamp series. All must have the same length, equal to pred_len.
Output
- List[pd.DataFrame]: List of N DataFrames in the same order as the input. Each DataFrame has columns [open, high, low, close, volume, amount], indexed by the corresponding entry from y_timestamp_list.
Key Constraint
All series must have identical dimensions:
- All historical DataFrames must have the same number of rows (same historical length).
- All future timestamp lists must have the same length (same prediction length, equal to pred_len).
If these constraints are violated, a ValueError is raised.
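Because predict_batch only validates these constraints and raises rather than fixing them, callers with series of uneven raw lengths must align them first. A minimal sketch of one way to do that, trimming each series to its most recent rows; `align_series` is a hypothetical helper, not part of the Kronos API:

```python
import numpy as np
import pandas as pd

def align_series(df_list, lookback):
    """Trim each DataFrame to its most recent `lookback` rows so every
    series in the batch has an identical historical length.
    Hypothetical helper, not part of the Kronos API."""
    for i, df in enumerate(df_list):
        if len(df) < lookback:
            raise ValueError(f"series {i} has only {len(df)} rows, need {lookback}")
    return [df.iloc[-lookback:].reset_index(drop=True) for df in df_list]

# Example: series of different raw lengths trimmed to a common 500-row window
rng = np.random.default_rng(0)
dfs = [pd.DataFrame(rng.random((n, 4)), columns=["open", "high", "low", "close"])
       for n in (620, 505, 500)]
aligned = align_series(dfs, 500)
assert all(len(df) == 500 for df in aligned)
```

Trimming to the tail keeps the most recent history, which is typically what matters for forecasting; padding shorter series instead would fabricate data.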
Internal Pipeline
1. Validate all inputs (types, columns, lengths)
2. For each series i:
a. Fill missing volume/amount columns
b. Compute temporal features via calc_time_stamps()
c. Instance-normalize: x = (x - mean) / (std + 1e-5), clip to [-clip, clip]
d. Store normalized data, temporal features, mean, std
3. Verify all series have consistent historical and prediction lengths
4. Stack into batch tensors:
x_batch: (N, seq_len, features)
x_stamp_batch: (N, seq_len, time_features)
y_stamp_batch: (N, pred_len, time_features)
5. Call self.generate() -> auto_regressive_inference()
Returns preds: (N, pred_len, features)
6. For each series i:
a. Denormalize: preds[i] * (std[i] + 1e-5) + mean[i]
b. Create DataFrame indexed by y_timestamp_list[i]
7. Return list of DataFrames
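The normalize/stack/denormalize portion of the pipeline (steps 2c, 4, and 6a) can be sketched in plain NumPy. This is an illustrative round trip on synthetic data, not the Kronos implementation; the clip value is an assumption standing in for the predictor's `self.clip`:

```python
import numpy as np

clip = 5.0  # assumed clip value; the real predictor uses self.clip
rng = np.random.default_rng(1)
# three series on very different scales, 500 rows x 6 features each
series = [rng.uniform(0, 100 * (i + 1), size=(500, 6)) for i in range(3)]

norms, means, stds = [], [], []
for x in series:                        # step 2c: per-series instance normalization
    m, s = np.mean(x, axis=0), np.std(x, axis=0)
    norms.append(np.clip((x - m) / (s + 1e-5), -clip, clip))
    means.append(m)
    stds.append(s)

x_batch = np.stack(norms, axis=0)       # step 4: shape (N, seq_len, features)
assert x_batch.shape == (3, 500, 6)

# step 6a: denormalizing inverts the transform, restoring each series' own scale
recon = x_batch[0] * (stds[0] + 1e-5) + means[0]
assert np.allclose(recon, series[0], atol=1e-6)
```

The round trip is exact (up to floating point and clipping) precisely because each series carries its own mean and std through the batch, which is why series on wildly different price scales can share one forward pass.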
Example
```python
import pandas as pd
from model import KronosTokenizer, Kronos, KronosPredictor

# Setup
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, device="cuda:0")

# Prepare multiple series (all same length)
df_list = [stock_a_df, stock_b_df, stock_c_df]  # each has 500 rows
x_timestamps = [
    pd.date_range("2024-01-01 09:30", periods=500, freq="1min"),
    pd.date_range("2024-01-01 09:30", periods=500, freq="1min"),
    pd.date_range("2024-01-01 09:30", periods=500, freq="1min"),
]
y_timestamps = [
    pd.date_range("2024-01-01 17:50", periods=120, freq="1min"),
    pd.date_range("2024-01-01 17:50", periods=120, freq="1min"),
    pd.date_range("2024-01-01 17:50", periods=120, freq="1min"),
]

# Batch predict
pred_dfs = predictor.predict_batch(
    df_list=df_list,
    x_timestamp_list=x_timestamps,
    y_timestamp_list=y_timestamps,
    pred_len=120,
    T=1.0,
    top_p=0.9,
    sample_count=3,
)

# pred_dfs is a list of 3 DataFrames
for i, pdf in enumerate(pred_dfs):
    print(f"Stock {i}: {pdf.shape}")
    print(pdf.head())
```
Source Code Reference
File: model/kronos.py, lines 562-661.
```python
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len,
                  T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
    # Validate inputs
    if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
        raise ValueError("...")

    # Per-series normalization
    for i in range(num_series):
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x_norm = (x - x_mean) / (x_std + 1e-5)
        x_norm = np.clip(x_norm, -self.clip, self.clip)
        # ... store normalized data ...

    # Verify consistent dimensions
    if len(set(seq_lens)) != 1:
        raise ValueError("Parallel prediction requires consistent historical lengths")

    # Stack and generate
    x_batch = np.stack(x_list, axis=0)
    preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, ...)

    # Per-series denormalization
    for i in range(num_series):
        preds_i = preds[i] * (stds[i] + 1e-5) + means[i]
        pred_dfs.append(pd.DataFrame(preds_i, ...))

    return pred_dfs
```
Error Handling
- Raises ValueError if df_list, x_timestamp_list, or y_timestamp_list are not list/tuple types.
- Raises ValueError if the three lists have inconsistent lengths.
- Raises ValueError if any individual DataFrame is not a pandas DataFrame or is missing required price columns.
- Raises ValueError if any DataFrame contains NaN values in price or volume columns.
- Raises ValueError if historical lengths are inconsistent across series.
- Raises ValueError if prediction lengths are inconsistent across series.
- Raises ValueError if any y_timestamp length does not match pred_len.
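The checks above can be restated as a standalone validation function. This is an illustrative re-statement of the documented rules, not the code inside KronosPredictor; `validate_batch` and `REQUIRED` are hypothetical names:

```python
import pandas as pd

REQUIRED = ["open", "high", "low", "close"]

def validate_batch(df_list, x_ts_list, y_ts_list, pred_len):
    """Mirror the documented ValueError conditions of predict_batch
    (illustrative only; the real validation lives inside the method)."""
    if not all(isinstance(v, (list, tuple)) for v in (df_list, x_ts_list, y_ts_list)):
        raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be lists/tuples")
    if not (len(df_list) == len(x_ts_list) == len(y_ts_list)):
        raise ValueError("the three lists must have the same length")
    for i, (df, x_ts, y_ts) in enumerate(zip(df_list, x_ts_list, y_ts_list)):
        if not isinstance(df, pd.DataFrame) or any(c not in df.columns for c in REQUIRED):
            raise ValueError(f"series {i}: expected a DataFrame with columns {REQUIRED}")
        if df[REQUIRED].isna().any().any():
            raise ValueError(f"series {i}: NaN values in price columns")
        if len(y_ts) != pred_len:
            raise ValueError(f"series {i}: y timestamps ({len(y_ts)}) != pred_len ({pred_len})")
    if len({len(df) for df in df_list}) != 1:
        raise ValueError("Parallel prediction requires consistent historical lengths")
```

Running such a check up front surfaces the failing series index before any GPU work is scheduled.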
Notes
- Normalization and denormalization are performed per-series, not across the batch. Each series retains its own scale.
- GPU memory scales with N_series * sample_count. For large batches with high sample counts, monitor GPU memory usage.
- The method delegates to self.generate(), which calls auto_regressive_inference() with the stacked batch tensors.
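A back-of-envelope sketch of why memory grows with N_series * sample_count: each sample adds a full row to the stacked batch. The helper below is hypothetical and counts only the float32 input/output value tensors; real GPU usage is dominated by model activations and attention caches, so treat this strictly as a lower bound:

```python
def batch_tensor_bytes(n_series, sample_count, seq_len, pred_len,
                       n_features=6, bytes_per=4):
    """Lower-bound footprint of the stacked value tensors alone
    (float32), ignoring activations and KV caches."""
    rows = n_series * sample_count          # effective batch dimension
    return rows * (seq_len + pred_len) * n_features * bytes_per

# e.g. 64 series x 3 samples, 512 history + 120 prediction steps
mb = batch_tensor_bytes(64, 3, 512, 120) / 1e6  # ~2.9 MB of raw values
```

The takeaway: the raw value tensors are tiny; when memory runs out at scale, it is the model's activations for the 192-row effective batch that are responsible, so reducing sample_count or splitting the series list are the practical levers.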