Implementation:Shiyu coder Kronos Plot Prediction Pattern
| Field | Value |
|---|---|
| Implementation Name | Plot_Prediction_Pattern |
| Repository | Shiyu_coder_Kronos |
| Repository URL | https://github.com/shiyu-coder/Kronos |
| Type | Pattern Doc |
| Source File | examples/prediction_example.py |
| Lines | L8-38 (plot_prediction function) |
| Implements Principle | Principle:Shiyu_coder_Kronos_Prediction_Visualization |
| Dependencies | pandas, matplotlib |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
This is a user-defined pattern (not a library API) that demonstrates how to visualize Kronos prediction results by overlaying predicted close price and volume against ground truth data using matplotlib.
Pattern Interface
```python
import pandas as pd


def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame) -> None:
    """
    Overlay predicted values on ground truth for visual forecast evaluation.

    Args:
        kline_df: Full DataFrame with ground truth OHLCV data (historical + future).
            Must include 'close' and 'volume' columns.
            Shape: (lookback + pred_len, N) where N >= 6 columns.
        pred_df: Prediction DataFrame returned by KronosPredictor.predict().
            Must include 'close' and 'volume' columns.
            Shape: (pred_len, 6) with columns [open, high, low, close, volume, amount].
            Note: its index is overwritten in place to match the tail of kline_df.

    Returns:
        None. Displays a matplotlib figure with 2 subplots.
    """
```
Input
| Parameter | Type | Description |
|---|---|---|
| kline_df | pd.DataFrame | Full ground truth data covering both the historical lookback window and the future prediction horizon. Must contain close and volume columns. |
| pred_df | pd.DataFrame | Prediction output from KronosPredictor.predict(). Must contain close and volume columns. |
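Because plot_prediction assumes this contract without checking it, a pre-flight check can report bad inputs before the index assignment fails with a less obvious pandas length-mismatch error. The sketch below is hypothetical (validate_plot_inputs and REQUIRED_COLUMNS are not part of Kronos). It checks the two required columns and that pred_df has between 1 and len(kline_df) rows. The lower bound matters because `kline_df.index[-0:]` selects the whole index, so an empty pred_df would also cause a length mismatch.

```python
import pandas as pd

# Hypothetical names: not part of the Kronos repository.
REQUIRED_COLUMNS = ("close", "volume")


def validate_plot_inputs(kline_df: pd.DataFrame, pred_df: pd.DataFrame) -> None:
    """Raise ValueError if the inputs break the plot_prediction contract."""
    for name, frame in (("kline_df", kline_df), ("pred_df", pred_df)):
        missing = [col for col in REQUIRED_COLUMNS if col not in frame.columns]
        if missing:
            raise ValueError(f"{name} is missing required columns: {missing}")

    n_pred, n_total = len(pred_df), len(kline_df)
    # Tail alignment needs 1 <= n_pred <= n_total; index[-0:] would select every row.
    if not 0 < n_pred <= n_total:
        raise ValueError(
            f"pred_df has {n_pred} rows; expected between 1 and {n_total} "
            "(the length of kline_df)"
        )
```

Calling validate_plot_inputs(kline_df, pred_df) immediately before plot_prediction(kline_df, pred_df) surfaces these problems with an explicit message.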
Output
A matplotlib figure with two subplots:
- Top subplot: Close price -- ground truth (blue) overlaid with prediction (red)
- Bottom subplot: Volume -- ground truth (blue) overlaid with prediction (red)
The figure is 8 x 6 inches (figsize=(8, 6)); the two subplots share the x-axis, grid lines are enabled, and plt.tight_layout() is applied before plt.show().
Full Implementation from Source
```python
import pandas as pd
import matplotlib.pyplot as plt


def plot_prediction(kline_df, pred_df):
    # Align prediction index to ground truth tail (note: modifies pred_df in place)
    pred_df.index = kline_df.index[-pred_df.shape[0]:]

    # Extract close price series
    sr_close = kline_df['close']
    sr_pred_close = pred_df['close']
    sr_close.name = 'Ground Truth'
    sr_pred_close.name = "Prediction"

    # Extract volume series
    sr_volume = kline_df['volume']
    sr_pred_volume = pred_df['volume']
    sr_volume.name = 'Ground Truth'
    sr_pred_volume.name = "Prediction"

    # Concatenate for overlay plotting
    close_df = pd.concat([sr_close, sr_pred_close], axis=1)
    volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)

    # Create figure with two subplots sharing x-axis
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    # Close price subplot
    ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
    ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
    ax1.set_ylabel('Close Price', fontsize=14)
    ax1.legend(loc='lower left', fontsize=12)
    ax1.grid(True)

    # Volume subplot
    ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
    ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
    ax2.set_ylabel('Volume', fontsize=14)
    ax2.legend(loc='upper left', fontsize=12)
    ax2.grid(True)

    plt.tight_layout()
    plt.show()
```
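The source function has two side effects worth knowing about: it overwrites pred_df.index, and it ends with plt.show(), which blocks until the window is closed when run as a script with a GUI backend and displays nothing under a non-interactive backend such as a headless server's. A minimal side-effect-free variant is sketched below. It is an assumption-level rewrite, not code from the Kronos repository: the name plot_prediction_fig is hypothetical, it uses DataFrame.set_axis and Series.rename (which return new objects) instead of assignments, and it builds a matplotlib.figure.Figure directly rather than through pyplot, so the caller decides whether to save or show it. The layout="tight" argument assumes matplotlib 3.5 or later.

```python
import pandas as pd
from matplotlib.figure import Figure


def plot_prediction_fig(kline_df: pd.DataFrame, pred_df: pd.DataFrame) -> Figure:
    """Hypothetical variant of plot_prediction that returns the Figure.

    Unlike the source function, it neither mutates its inputs nor calls plt.show().
    """
    n_pred = len(pred_df)
    # set_axis returns a new DataFrame, so the caller's pred_df keeps its index
    pred_aligned = pred_df.set_axis(kline_df.index[-n_pred:], axis=0)

    # Figure built without pyplot: no global state and no GUI backend required
    fig = Figure(figsize=(8, 6), layout="tight")
    ax1, ax2 = fig.subplots(2, 1, sharex=True)

    # Same styling as the source: close legend lower left, volume legend upper left
    for ax, col, ylabel, legend_loc in [
        (ax1, "close", "Close Price", "lower left"),
        (ax2, "volume", "Volume", "upper left"),
    ]:
        # rename() returns new Series objects, leaving kline_df and pred_df untouched
        both = pd.concat(
            [kline_df[col].rename("Ground Truth"), pred_aligned[col].rename("Prediction")],
            axis=1,
        )
        ax.plot(both["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5)
        ax.plot(both["Prediction"], label="Prediction", color="red", linewidth=1.5)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.legend(loc=legend_loc, fontsize=12)
        ax.grid(True)

    return fig
```

For example, plot_prediction_fig(kline_df, pred_df).savefig("prediction.png", dpi=150) writes the plot to disk without opening a window.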
Usage Example
Complete end-to-end usage with data preparation and prediction:
```python
import pandas as pd
# `model` is the package at the root of the Kronos repository;
# make the repository root importable (e.g. via sys.path) before this import.
from model import Kronos, KronosTokenizer, KronosPredictor

# Load model
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, max_context=512)

# Prepare data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
lookback = 400
pred_len = 120
x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
x_timestamp = df.loc[:lookback-1, 'timestamps']
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']

# Generate predictions
pred_df = predictor.predict(
    df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
    pred_len=pred_len, T=1.0, top_p=0.9, sample_count=1, verbose=True
)

# Combine historical and future ground truth for the full plot
kline_df = df.loc[:lookback+pred_len-1]

# Visualize (plot_prediction is the function defined above)
plot_prediction(kline_df, pred_df)
```
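Running the example above requires the data file and downloaded model weights. To exercise plot_prediction offline, a synthetic stand-in can be used instead. In the sketch below, make_synthetic_inputs is hypothetical: the random-walk prices are not market data, and the fake "prediction" is simply the true future multiplied by about 1% noise. It only mimics the shapes used above: kline_df has a timestamps column plus the six OHLCV/amount columns, and pred_df has pred_len rows in the column order listed for KronosPredictor.predict() in the interface docstring.

```python
import numpy as np
import pandas as pd

COLUMNS = ["open", "high", "low", "close", "volume", "amount"]


def make_synthetic_inputs(lookback=400, pred_len=120, seed=0):
    """Hypothetical stand-in for the CSV file and KronosPredictor.predict()."""
    rng = np.random.default_rng(seed)
    n = lookback + pred_len

    # Random-walk close prices; each bar opens at the previous close
    close = 10.0 + np.cumsum(rng.normal(0.0, 0.02, size=n))
    open_ = np.concatenate([[close[0]], close[:-1]])
    spread = np.abs(rng.normal(0.0, 0.01, size=n))
    volume = rng.integers(1_000, 5_000, size=n).astype(float)

    kline_df = pd.DataFrame({
        "timestamps": pd.date_range("2024-01-02 09:35", periods=n, freq="5min"),
        "open": open_,
        "high": np.maximum(open_, close) + spread,
        "low": np.minimum(open_, close) - spread,
        "close": close,
        "volume": volume,
        "amount": volume * close,
    })

    # Fake prediction: true future rows scaled by ~1% multiplicative noise,
    # indexed by the future timestamps (plot_prediction overwrites this index)
    future = kline_df.loc[lookback:, COLUMNS]
    noise = rng.normal(1.0, 0.01, size=future.shape)
    future_index = pd.DatetimeIndex(kline_df["timestamps"].iloc[lookback:])
    pred_df = (future * noise).set_axis(future_index, axis=0)
    return kline_df, pred_df
```

Passing the two frames to plot_prediction(kline_df, pred_df) exercises the plotting code path without loading the tokenizer or model.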
Key Implementation Details
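- Index alignment: `pred_df.index = kline_df.index[-pred_df.shape[0]:]` replaces the prediction's index with the last `pred_len` index labels of `kline_df`, so each prediction row lines up with the ground truth row for the same time step. The assignment modifies the caller's `pred_df` in place.
- Concatenation trick: `pd.concat([ground_truth, prediction], axis=1)` outer-joins the two series on their shared index. The Prediction column is NaN across the lookback window, and matplotlib breaks a line at NaN values, so the red line covers only the prediction horizon while the blue line spans the full range.
- Shared x-axis: `sharex=True` keeps both subplots on the same x range for synchronized comparison. The x values are `kline_df`'s index labels, which in the usage example are integer row numbers rather than timestamps.
- Color convention: Blue for ground truth, red for predictions -- consistent across both subplots.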
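The alignment and concatenation steps can be seen on a toy series. The helper overlay_frame below is hypothetical; it reproduces the two steps for a single column but uses Series.set_axis instead of assigning to .index, so the input series is not modified.

```python
import pandas as pd


def overlay_frame(truth: pd.Series, pred: pd.Series) -> pd.DataFrame:
    """Hypothetical helper: align pred to the tail of truth, then outer-join."""
    # Tail alignment without mutating the caller's series
    pred = pred.set_axis(truth.index[-len(pred):])
    # Outer join on the index: rows without a prediction get NaN
    return pd.concat([truth.rename("Ground Truth"), pred.rename("Prediction")], axis=1)


truth = pd.Series([10.0, 11.0, 12.0, 13.0, 14.0], name="close")
pred = pd.Series([12.2, 12.9], name="close")
frame = overlay_frame(truth, pred)
```

In the resulting frame the first three Prediction values are NaN and the last two hold 12.2 and 12.9, so plotting that column draws the red line only over index labels 3 and 4.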
See Also
- Principle:Shiyu_coder_Kronos_Prediction_Visualization -- The principle this pattern implements
- Implementation:Shiyu_coder_Kronos_Candlestick_Data_Preparation_Pattern -- Data preparation that precedes visualization