Implementation:Shiyu coder Kronos KronosPredictor Init

From Leeroopedia


Field Value
implementation_name KronosPredictor_Init
repo Shiyu_coder_Kronos
type API Doc
source_file model/kronos.py:L482-506
class KronosPredictor
implements Principle:Shiyu_coder_Kronos_Predictor_Initialization
last_updated 2026-02-09 14:00 GMT

Summary

The KronosPredictor constructor wraps a loaded Kronos model and KronosTokenizer into a unified prediction interface, automatically detecting the best available compute device and moving both models to it.

API Signature

KronosPredictor(
    model: Kronos,
    tokenizer: KronosTokenizer,
    device: Optional[str] = None,
    max_context: int = 512,
    clip: int = 5
) -> KronosPredictor

Import

from model import KronosPredictor
# or
from model.kronos import KronosPredictor

Parameters

Parameter Type Default Description
model Kronos (required) A loaded Kronos autoregressive model (from Kronos.from_pretrained()).
tokenizer KronosTokenizer (required) A loaded KronosTokenizer (from KronosTokenizer.from_pretrained()).
device str None Target device string (e.g., "cuda:0", "mps", "cpu"). Auto-detected if None.
max_context int 512 Maximum context window length for the Transformer during autoregressive generation.
clip int 5 Clipping bound for normalized input values. Values are clipped to [-clip, clip].
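The effect of the clip bound can be illustrated with a minimal sketch (plain Python, not the library's own normalization code; clip_value is a hypothetical helper): each normalized value is saturated to the [-clip, clip] interval before reaching the model.

```python
def clip_value(v: float, clip: int = 5) -> float:
    """Hypothetical helper: bound a normalized value to [-clip, clip]."""
    return max(-clip, min(clip, v))

# Out-of-range values are saturated; in-range values pass through unchanged.
print([clip_value(v) for v in [12.0, -8.0, 0.25, 4.5]])  # [5, -5, 0.25, 4.5]
```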

Input

  • model: A pre-trained Kronos nn.Module, typically loaded via Kronos.from_pretrained().
  • tokenizer: A pre-trained KronosTokenizer nn.Module, typically loaded via KronosTokenizer.from_pretrained().

Output

  • KronosPredictor: An initialized predictor instance with both models moved to the target device, ready for predict() or predict_batch() calls.

Device Auto-Detection

When device=None, the following priority order is used:

if torch.cuda.is_available():
    device = "cuda:0"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
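The same priority order can be wrapped in a self-contained helper for reuse elsewhere (detect_device is a hypothetical name; the library performs this check inline in __init__):

```python
import torch

def detect_device() -> str:
    """Mirror the constructor's fallback order: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(detect_device())  # "cuda:0", "mps", or "cpu" depending on hardware
```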

Internal Attributes

After initialization, the predictor stores these internal attributes:

Attribute Value
self.tokenizer KronosTokenizer moved to target device
self.model Kronos model moved to target device
self.device Target device string
self.max_context Maximum context window length
self.clip Clipping bound
self.price_cols ['open', 'high', 'low', 'close']
self.vol_col 'volume'
self.amt_vol 'amount'
self.time_cols ['minute', 'hour', 'weekday', 'day', 'month']

Example

from model import KronosTokenizer, Kronos, KronosPredictor

# Load components
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")

# Create predictor with auto-detected device
predictor = KronosPredictor(model, tokenizer)

# Create predictor with explicit device
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)

Source Code Reference

File: model/kronos.py, lines 482-506.

class KronosPredictor:

    def __init__(self, model, tokenizer, device=None, max_context=512, clip=5):
        self.tokenizer = tokenizer
        self.model = model
        self.max_context = max_context
        self.clip = clip
        self.price_cols = ['open', 'high', 'low', 'close']
        self.vol_col = 'volume'
        self.amt_vol = 'amount'
        self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']

        # Auto-detect device if not specified
        if device is None:
            if torch.cuda.is_available():
                device = "cuda:0"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        self.device = device
        self.tokenizer = self.tokenizer.to(self.device)
        self.model = self.model.to(self.device)

Notes

  • Both the model and tokenizer are moved to the target device during __init__. There is no need to manually call .to(device) after creating the predictor.
  • The predictor does not call .eval() on the models. Call predictor.model.eval() and predictor.tokenizer.eval() before inference if dropout and other training-mode behavior should be disabled.
  • The max_context parameter controls the sliding window size during autoregressive generation. Longer contexts use more memory but may capture longer-range dependencies.
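Because __init__ leaves both modules in whatever mode they were loaded in, disabling training-mode behavior looks like the following sketch (shown with a stand-in nn.Module, since loading real Kronos weights is out of scope here; with the real library the calls would target predictor.model and predictor.tokenizer):

```python
import torch.nn as nn

# Stand-in for predictor.model / predictor.tokenizer.
module = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
module.eval()  # switches off dropout for inference
print(module.training)  # False
```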
