Implementation:Shiyu coder Kronos Plot Prediction Pattern

From Leeroopedia


| Field | Value |
| --- | --- |
| Implementation Name | Plot_Prediction_Pattern |
| Repository | Shiyu_coder_Kronos |
| Repository URL | https://github.com/shiyu-coder/Kronos |
| Type | Pattern Doc |
| Source File | examples/prediction_example.py |
| Lines | L8-38 (plot_prediction function) |
| Implements Principle | Principle:Shiyu_coder_Kronos_Prediction_Visualization |
| Dependencies | pandas, matplotlib |
| Last Updated | 2026-02-09 14:00 GMT |

Overview

This is a user-defined pattern (not a library API) that shows how to visualize Kronos prediction results with matplotlib by overlaying the predicted close price and volume on the ground truth data.

Pattern Interface

def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame) -> None:
    """
    Overlay predicted values on ground truth for visual forecast evaluation.

    Args:
        kline_df: Full DataFrame with ground truth OHLCV data (historical + future).
                  Must include 'close' and 'volume' columns.
                  Shape: (lookback + pred_len, N) where N >= 6 columns.
        pred_df:  Prediction DataFrame returned by KronosPredictor.predict().
                  Must include 'close' and 'volume' columns.
                  Shape: (pred_len, 6) with columns [open, high, low, close, volume, amount].

    Returns:
        None. Displays a matplotlib figure with 2 subplots.
    """

Input

| Parameter | Type | Description |
| --- | --- | --- |
| kline_df | pd.DataFrame | Full ground truth data covering both the historical lookback window and the future prediction horizon. Must contain close and volume columns. |
| pred_df | pd.DataFrame | Prediction output from KronosPredictor.predict(). Must contain close and volume columns. |
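
The input requirements above can be checked before plotting. The following is a minimal sketch, not part of the Kronos repository: the helper name check_plot_inputs and its error messages are assumptions made for illustration. It verifies the two required columns and that pred_df has at least one row and no more rows than kline_df, since the index alignment inside plot_prediction would otherwise fail with a less descriptive length-mismatch error.

```python
import pandas as pd

# Columns that plot_prediction reads from both frames.
REQUIRED_COLUMNS = ("close", "volume")


def check_plot_inputs(kline_df: pd.DataFrame, pred_df: pd.DataFrame) -> None:
    """Hypothetical pre-flight check (not in the Kronos repo).

    Raises ValueError if either frame lacks a required column or if
    pred_df cannot be aligned to the tail of kline_df.
    """
    for name, frame in (("kline_df", kline_df), ("pred_df", pred_df)):
        missing = [col for col in REQUIRED_COLUMNS if col not in frame.columns]
        if missing:
            raise ValueError(f"{name} is missing required columns: {missing}")

    # kline_df.index[-n:] only yields n labels when 0 < n <= len(kline_df).
    if not 0 < len(pred_df) <= len(kline_df):
        raise ValueError(
            f"pred_df has {len(pred_df)} rows; expected between 1 and "
            f"{len(kline_df)} (the length of kline_df)"
        )
```

Calling check_plot_inputs(kline_df, pred_df) immediately before plot_prediction(kline_df, pred_df) surfaces these problems with a readable message.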

Output

A matplotlib figure with two subplots:

  • Top subplot: Close price -- ground truth (blue) overlaid with prediction (red)
  • Bottom subplot: Volume -- ground truth (blue) overlaid with prediction (red)

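The figure is created with figsize=(8, 6), that is, 8 by 6 inches. The two subplots share an x-axis, both have grid lines enabled, and plt.tight_layout() is applied before the figure is shown.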

Full Implementation from Source

import pandas as pd
import matplotlib.pyplot as plt

def plot_prediction(kline_df, pred_df):
    # Align prediction index to ground truth tail
    pred_df.index = kline_df.index[-pred_df.shape[0]:]

    # Extract close price series
    sr_close = kline_df['close']
    sr_pred_close = pred_df['close']
    sr_close.name = 'Ground Truth'
    sr_pred_close.name = "Prediction"

    # Extract volume series
    sr_volume = kline_df['volume']
    sr_pred_volume = pred_df['volume']
    sr_volume.name = 'Ground Truth'
    sr_pred_volume.name = "Prediction"

    # Concatenate for overlay plotting
    close_df = pd.concat([sr_close, sr_pred_close], axis=1)
    volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)

    # Create figure with two subplots sharing x-axis
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    # Close price subplot
    ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
    ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
    ax1.set_ylabel('Close Price', fontsize=14)
    ax1.legend(loc='lower left', fontsize=12)
    ax1.grid(True)

    # Volume subplot
    ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
    ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
    ax2.set_ylabel('Volume', fontsize=14)
    ax2.legend(loc='upper left', fontsize=12)
    ax2.grid(True)

    plt.tight_layout()
    plt.show()
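
The function above reassigns pred_df.index on the caller's DataFrame and ends with plt.show(), which needs an interactive display. For scripts or servers without one, a variant along the following lines may be more convenient. This is a sketch under stated assumptions, not code from the Kronos repository: the name plot_prediction_figure is hypothetical, the Agg backend is chosen on the assumption that no display is available, and the legends use matplotlib's default placement instead of the fixed locations in the source.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless environment, no GUI backend

import matplotlib.pyplot as plt
import pandas as pd


def plot_prediction_figure(kline_df: pd.DataFrame, pred_df: pd.DataFrame):
    """Hypothetical variant of plot_prediction that returns the Figure.

    Works on a copy of pred_df, so the caller's index is left untouched.
    """
    pred = pred_df.copy()
    # Same alignment rule as the source: reuse the last len(pred) labels.
    pred.index = kline_df.index[-len(pred):]

    fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
    panels = (("close", "Close Price"), ("volume", "Volume"))
    for ax, (column, ylabel) in zip(axes, panels):
        ax.plot(kline_df[column], label="Ground Truth", color="blue", linewidth=1.5)
        ax.plot(pred[column], label="Prediction", color="red", linewidth=1.5)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.legend(fontsize=12)
        ax.grid(True)

    fig.tight_layout()
    return fig
```

The returned figure can then be written to disk, for example with fig.savefig("prediction.png"), where the file name is only a placeholder.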

Usage Example

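Complete end-to-end usage with data preparation and prediction. The example assumes that plot_prediction from the previous section is already defined and that the script runs from a location where the Kronos model package is importable: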

import pandas as pd
from model import Kronos, KronosTokenizer, KronosPredictor

# Load model
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, max_context=512)

# Prepare data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])

lookback = 400
pred_len = 120

x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
x_timestamp = df.loc[:lookback-1, 'timestamps']
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']

# Generate predictions
pred_df = predictor.predict(
    df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
    pred_len=pred_len, T=1.0, top_p=0.9, sample_count=1, verbose=True
)

# Combine historical and future ground truth for the full plot
kline_df = df.loc[:lookback+pred_len-1]

# Visualize
plot_prediction(kline_df, pred_df)

Key Implementation Details

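  • Index alignment: pred_df.index = kline_df.index[-pred_df.shape[0]:] relabels the prediction rows with the last pred_len index labels of kline_df, so both frames share the same x positions. This assumes kline_df ends exactly at the end of the prediction horizon, and it modifies the caller's pred_df in place.
  • Concatenation trick: pd.concat([ground_truth, prediction], axis=1) joins the two series on their index. The Prediction column is NaN for every row in the lookback window, and matplotlib skips NaN points, so the red line is drawn only over the prediction horizon while the blue line spans the full range.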
  • Shared x-axis: sharex=True ensures both subplots align temporally for synchronized comparison.
  • Color convention: Blue for ground truth, red for predictions -- consistent across both subplots.
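
The index alignment and concatenation steps listed above can be seen in isolation on a small example. The following sketch uses synthetic values and a made-up five-row time index, both assumptions for illustration only; it involves no Kronos model and no plotting.

```python
import pandas as pd

# Synthetic ground truth: 5 rows, of which the last 2 are the forecast horizon.
idx = pd.date_range("2024-01-01 09:30", periods=5, freq="5min")
kline_df = pd.DataFrame({"close": [10.0, 10.1, 10.2, 10.3, 10.4]}, index=idx)

# Synthetic prediction frame carrying its own, unrelated index (0, 1).
pred_df = pd.DataFrame({"close": [10.25, 10.35]})

# Index alignment: reuse the last len(pred_df) ground-truth labels.
pred_df.index = kline_df.index[-pred_df.shape[0]:]

# Column-wise concat joins on the index, so the Prediction column is NaN
# wherever only ground truth exists (the first 3 rows here).
close_df = pd.concat(
    [
        kline_df["close"].rename("Ground Truth"),
        pred_df["close"].rename("Prediction"),
    ],
    axis=1,
)
print(close_df)
```

Using rename here returns new named Series and leaves the columns of kline_df and pred_df untouched, whereas the source implementation assigns to the name attribute directly.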
