
Implementation:Shiyu coder Kronos Auto Regressive Inference

From Leeroopedia


Field Value
implementation_name Auto_Regressive_Inference
repo Shiyu_coder_Kronos
type API Doc
source_file model/kronos.py:L389-469
function auto_regressive_inference (module-level function)
implements Principle:Shiyu_coder_Kronos_Autoregressive_Token_Generation
last_updated 2026-02-09 14:00 GMT

Summary

The auto_regressive_inference function is the core generation loop that produces hierarchical discrete tokens step-by-step using a sliding context window, then decodes them back to continuous values. It is called internally by both KronosPredictor.predict() and KronosPredictor.predict_batch().

API Signature

auto_regressive_inference(
    tokenizer: KronosTokenizer,
    model: Kronos,
    x: torch.Tensor,
    x_stamp: torch.Tensor,
    y_stamp: torch.Tensor,
    max_context: int,
    pred_len: int,
    clip: int = 5,
    T: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.99,
    sample_count: int = 5,
    verbose: bool = False
) -> np.ndarray

Import

from model.kronos import auto_regressive_inference

Parameters

Parameter Type Default Description
tokenizer KronosTokenizer (required) The loaded VQ-VAE tokenizer for encoding/decoding.
model Kronos (required) The loaded autoregressive Transformer model.
x torch.Tensor (required) Input tensor of shape (batch, seq_len, features). Normalized continuous OHLCV data.
x_stamp torch.Tensor (required) Historical temporal features of shape (batch, seq_len, time_features).
y_stamp torch.Tensor (required) Future temporal features of shape (batch, pred_len, time_features).
max_context int (required) Maximum context window length for the sliding buffer.
pred_len int (required) Number of future tokens to generate.
clip int 5 Clipping bound applied to input tensor values.
T float 1.0 Sampling temperature for logit scaling.
top_k int 0 Top-k filtering threshold. 0 disables top-k.
top_p float 0.99 Top-p (nucleus sampling) threshold.
sample_count int 5 Number of parallel samples per batch item, averaged after decoding.
verbose bool False Whether to display a progress bar (uses trange).

Input

  • x (torch.Tensor): Normalized continuous data of shape (batch_size, seq_len, features). Already clipped to [-clip, clip].
  • x_stamp (torch.Tensor): Temporal features for historical timesteps.
  • y_stamp (torch.Tensor): Temporal features for future (prediction) timesteps.

Output

  • np.ndarray: Predicted normalized values of shape (batch_size, total_seq_len, features), where total_seq_len = seq_len + pred_len. The caller typically slices [:, -pred_len:, :] to extract only the predicted portion.
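The slicing the caller performs can be illustrated with a dummy array standing in for the function's return value (shapes only; no model required):

```python
import numpy as np

# Stand-in for the return value: (batch_size, seq_len + pred_len, features)
batch_size, seq_len, pred_len, features = 2, 4, 3, 5
preds = np.zeros((batch_size, seq_len + pred_len, features))

# The caller keeps only the predicted portion
future = preds[:, -pred_len:, :]
assert future.shape == (batch_size, pred_len, features)
```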

Algorithm

The function operates in four phases:

Phase 1: Preparation

# Clip input values
x = torch.clip(x, -clip, clip)

# Replicate batch for multi-sample averaging
x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, seq_len, features)
# Effective batch size becomes: original_batch * sample_count

# Encode historical data to hierarchical tokens
x_token = tokenizer.encode(x, half=True)  # Returns (s1_indices, s2_indices)

# Concatenate historical + future temporal features
full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
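The replication step above can be reproduced with a numpy analogue (numpy stands in for torch here so the snippet runs standalone; the tensor semantics are the same): each batch item is duplicated sample_count times along the batch axis, so samples for item 0 occupy the first sample_count rows, item 1 the next block, and so on.

```python
import numpy as np

batch, seq_len, features, sample_count = 2, 6, 5, 3
x = np.random.randn(batch, seq_len, features)

# unsqueeze(1).repeat(1, sample_count, 1, 1) equivalent:
x_rep = np.repeat(x[:, None], sample_count, axis=1)  # (batch, sample_count, seq_len, features)
x_rep = x_rep.reshape(-1, seq_len, features)         # (batch * sample_count, seq_len, features)

assert x_rep.shape == (batch * sample_count, seq_len, features)
# Rows 0 .. sample_count-1 are identical copies of batch item 0
assert np.allclose(x_rep[0], x_rep[sample_count - 1])
```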

Phase 2: Buffer Initialization

# Initialize sliding token buffers of length max_context
# (batch_size here is the effective batch: original_batch * sample_count)
pre_buffer = zeros(batch_size, max_context)   # s1 tokens
post_buffer = zeros(batch_size, max_context)  # s2 tokens

# Fill with historical tokens (up to max_context length)
buffer_len = min(initial_seq_len, max_context)
pre_buffer[:, :buffer_len] = x_token[0][:, -buffer_len:]
post_buffer[:, :buffer_len] = x_token[1][:, -buffer_len:]

Phase 3: Autoregressive Generation Loop

For each step i in range(pred_len):
    1. Select input tokens from the buffer (the current fill length, capped at max_context)
    2. Select corresponding temporal features
    3. model.decode_s1(s1_tokens, s2_tokens, stamps) -> s1_logits, context
    4. Sample s1 token: sample_from_logits(s1_logits[:, -1, :], T, top_k, top_p)
    5. model.decode_s2(context, sampled_s1) -> s2_logits
    6. Sample s2 token: sample_from_logits(s2_logits[:, -1, :], T, top_k, top_p)
    7. Store generated tokens
    8. Update sliding buffer (append or shift-and-append)
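The repository's sample_from_logits helper is not reproduced on this page. As an illustration only, a minimal numpy sketch of temperature scaling with top-k and top-p (nucleus) filtering could look like the following; the function name and signature mirror the call sites above, but the body is an assumption, not the repository's implementation:

```python
import numpy as np

def sample_from_logits(logits, T=1.0, top_k=0, top_p=0.99, rng=None):
    # Hypothetical numpy sketch; the real helper operates on torch tensors.
    rng = rng or np.random.default_rng(0)
    logits = np.asarray(logits, dtype=np.float64) / T   # temperature scaling
    if top_k > 0:
        kth = np.sort(logits)[-top_k]                   # k-th largest logit
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())               # stable softmax
    probs /= probs.sum()
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]                 # tokens by descending prob
        cum = np.cumsum(probs[order])
        cutoff = np.searchsorted(cum, top_p) + 1        # smallest set with mass >= top_p
        keep = order[:cutoff]
        filtered = np.zeros_like(probs)
        filtered[keep] = probs[keep]
        probs = filtered / filtered.sum()               # renormalize the nucleus
    return int(rng.choice(len(probs), p=probs))
```

At very low temperature the distribution collapses onto the argmax, which is a quick sanity check: `sample_from_logits([1.0, 5.0, 2.0], T=0.01)` returns 1.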

Phase 4: Decoding and Averaging

# Concatenate historical + generated tokens
full_pre = torch.cat([x_token[0], generated_pre], dim=1)
full_post = torch.cat([x_token[1], generated_post], dim=1)

# Decode back to continuous values (using last max_context tokens)
z = tokenizer.decode([full_pre[:, -max_context:], full_post[:, -max_context:]], half=True)

# Separate the sample dimension, then average across samples
z = z.reshape(-1, sample_count, decoded_len, features)  # decoded_len = timesteps returned by decode
preds = np.mean(z.cpu().numpy(), axis=1)                # average over the sample dimension

Source Code Reference

File: model/kronos.py, lines 389-469.

def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context,
                               pred_len, clip=5, T=1.0, top_k=0, top_p=0.99,
                               sample_count=5, verbose=False):
    with torch.no_grad():
        x = torch.clip(x, -clip, clip)
        device = x.device
        x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(...)
        x_token = tokenizer.encode(x, half=True)
        # ... sliding buffer generation loop ...
        z = tokenizer.decode(input_tokens, half=True)
        z = z.reshape(-1, sample_count, z.size(1), z.size(2))
        preds = np.mean(z.cpu().numpy(), axis=1)
        return preds

Notes

  • The entire function runs inside torch.no_grad() for inference efficiency.
  • The sliding buffer uses torch.roll() to shift tokens left once the buffer is full, so the buffer stays at a fixed size and memory does not grow with pred_len.
  • This function is the core generation engine called by both predict() (single series) and predict_batch() (multiple series). It does not perform normalization or denormalization; those are handled by the caller.
  • The output includes both historical and predicted tokens decoded together, so the caller must slice [:, -pred_len:, :] to isolate predictions.
  • The sample_count replication happens along the batch dimension, so GPU memory usage scales linearly with sample_count.
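The append-or-shift buffer update described in the notes can be sketched in numpy (np.roll stands in for torch.roll; the helper name and its return convention are assumptions for illustration):

```python
import numpy as np

def append_token(buffer, buf_len, token, max_context):
    # Sketch of the append-or-shift sliding-buffer update.
    if buf_len < max_context:
        buffer[:, buf_len] = token          # room left: write into the next free slot
        return buffer, buf_len + 1
    buffer = np.roll(buffer, -1, axis=1)    # buffer full: shift everything left by one
    buffer[:, -1] = token                   # newest token takes the freed last slot
    return buffer, buf_len

buf = np.zeros((1, 4), dtype=int)
n = 0
for t in [1, 2, 3, 4, 5]:
    buf, n = append_token(buf, n, t, max_context=4)
# buf is now [[2, 3, 4, 5]]: the oldest token was dropped once the window filled
```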
