Implementation: Shiyu_coder_Kronos Auto_Regressive_Inference
| Field | Value |
|---|---|
| implementation_name | Auto_Regressive_Inference |
| repo | Shiyu_coder_Kronos |
| type | API Doc |
| source_file | model/kronos.py:L389-469 |
| function | auto_regressive_inference (module-level function) |
| implements | Principle:Shiyu_coder_Kronos_Autoregressive_Token_Generation |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The auto_regressive_inference function is the core generation loop that produces hierarchical discrete tokens step-by-step using a sliding context window, then decodes them back to continuous values. It is called internally by both KronosPredictor.predict() and KronosPredictor.predict_batch().
API Signature
auto_regressive_inference(
tokenizer: KronosTokenizer,
model: Kronos,
x: torch.Tensor,
x_stamp: torch.Tensor,
y_stamp: torch.Tensor,
max_context: int,
pred_len: int,
clip: int = 5,
T: float = 1.0,
top_k: int = 0,
top_p: float = 0.99,
sample_count: int = 5,
verbose: bool = False
) -> np.ndarray
Import
from model.kronos import auto_regressive_inference
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| tokenizer | KronosTokenizer | (required) | The loaded VQ-VAE tokenizer for encoding/decoding. |
| model | Kronos | (required) | The loaded autoregressive Transformer model. |
| x | torch.Tensor | (required) | Input tensor of shape (batch, seq_len, features); normalized continuous OHLCV data. |
| x_stamp | torch.Tensor | (required) | Historical temporal features of shape (batch, seq_len, time_features). |
| y_stamp | torch.Tensor | (required) | Future temporal features of shape (batch, pred_len, time_features). |
| max_context | int | (required) | Maximum context window length for the sliding buffer. |
| pred_len | int | (required) | Number of future tokens to generate. |
| clip | int | 5 | Clipping bound applied to input tensor values. |
| T | float | 1.0 | Sampling temperature for logit scaling. |
| top_k | int | 0 | Top-k filtering threshold. 0 disables top-k. |
| top_p | float | 0.99 | Top-p (nucleus sampling) threshold. |
| sample_count | int | 5 | Number of parallel samples per batch item, averaged after decoding. |
| verbose | bool | False | Whether to display a progress bar (uses trange). |
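The sampling parameters T, top_k, and top_p are applied together inside the repo's sample_from_logits helper. The sketch below is a hypothetical reimplementation of the standard temperature / top-k / nucleus-sampling recipe for illustration only; it is not the repo's exact code, and the default top_p here is 1.0 (disabled) rather than the doc's 0.99.

```python
import torch

def sample_from_logits(logits, T=1.0, top_k=0, top_p=1.0):
    """Temperature + top-k + top-p (nucleus) sampling over the last dimension."""
    # Temperature scaling: higher T flattens the distribution, lower T sharpens it.
    logits = logits / T
    # Top-k: keep only the k largest logits (0 disables the filter).
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: drop the tail of the sorted distribution whose cumulative
    # probability exceeds top_p, always keeping the most likely token.
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# With a dominant logit and top_k=1, only the argmax survives the filter,
# so sampling is deterministic here: token.item() == 2.
token = sample_from_logits(torch.tensor([[0.0, 0.0, 10.0]]), T=1.0, top_k=1)
```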
Input
- x (torch.Tensor): Normalized continuous data of shape (batch_size, seq_len, features); clipped internally to [-clip, clip] in Phase 1.
- x_stamp (torch.Tensor): Temporal features for the historical timesteps.
- y_stamp (torch.Tensor): Temporal features for the future (prediction) timesteps.
Output
- np.ndarray: Predicted normalized values of shape (batch_size, total_seq_len, features), where total_seq_len = seq_len + pred_len. The caller typically slices [:, -pred_len:, :] to extract only the predicted portion.
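The slicing convention can be illustrated with a stand-in array of the same shape (the dimensions below are placeholder values, not defaults from the repo):

```python
import numpy as np

seq_len, pred_len, features = 400, 120, 6

# Stand-in for the function's return value: history and predictions decoded together.
preds_full = np.zeros((1, seq_len + pred_len, features))

# Callers keep only the predicted tail.
preds = preds_full[:, -pred_len:, :]
assert preds.shape == (1, pred_len, features)
```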
Algorithm
The function operates in four phases:
Phase 1: Preparation
# Clip input values
x = torch.clip(x, -clip, clip)
# Replicate batch for multi-sample averaging
x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, seq_len, features)
# Effective batch size becomes: original_batch * sample_count
# Encode historical data to hierarchical tokens
x_token = tokenizer.encode(x, half=True) # Returns (s1_indices, s2_indices)
# Concatenate historical + future temporal features
full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
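The replicate-and-flatten step can be checked on a dummy tensor (shapes here are illustrative, not repo defaults). Each original series ends up repeated sample_count times in a contiguous block along the batch dimension:

```python
import torch

batch, seq_len, features, sample_count = 2, 8, 6, 5
x = torch.randn(batch, seq_len, features)

# (batch, seq, feat) -> (batch, sample_count, seq, feat) -> (batch * sample_count, seq, feat)
x_rep = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, seq_len, features)

assert x_rep.shape == (batch * sample_count, seq_len, features)
# Copies of series 0 occupy rows 0..sample_count-1, series 1 the next block, etc.
assert torch.equal(x_rep[0], x[0]) and torch.equal(x_rep[sample_count - 1], x[0])
assert torch.equal(x_rep[sample_count], x[1])
```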
Phase 2: Buffer Initialization
# Initialize sliding buffers of size max_context
pre_buffer = zeros(batch_size, max_context) # s1 tokens
post_buffer = zeros(batch_size, max_context) # s2 tokens
# Fill with historical tokens (up to max_context length)
buffer_len = min(initial_seq_len, max_context)
pre_buffer[:, :buffer_len] = x_token[0][:, -buffer_len:]
post_buffer[:, :buffer_len] = x_token[1][:, -buffer_len:]
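The buffer-fill logic above keeps only the most recent max_context tokens of history. A small sketch with synthetic token ids (sizes are illustrative):

```python
import torch

batch, max_context, initial_seq_len = 4, 512, 600

# Stand-in token ids for one token stream (e.g. s1): 0..599 per batch item.
s1_tokens = torch.arange(initial_seq_len).repeat(batch, 1)

pre_buffer = torch.zeros(batch, max_context, dtype=torch.long)
buffer_len = min(initial_seq_len, max_context)
# When history exceeds max_context, only its tail is kept.
pre_buffer[:, :buffer_len] = s1_tokens[:, -buffer_len:]

assert buffer_len == max_context
assert pre_buffer[0, 0].item() == initial_seq_len - max_context  # oldest kept token
assert pre_buffer[0, -1].item() == initial_seq_len - 1           # newest token
```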
Phase 3: Autoregressive Generation Loop
For each step i in range(pred_len):
1. Select input tokens from buffer (up to window_len)
2. Select corresponding temporal features
3. model.decode_s1(s1_tokens, s2_tokens, stamps) -> s1_logits, context
4. Sample s1 token: sample_from_logits(s1_logits[:, -1, :], T, top_k, top_p)
5. model.decode_s2(context, sampled_s1) -> s2_logits
6. Sample s2 token: sample_from_logits(s2_logits[:, -1, :], T, top_k, top_p)
7. Store generated tokens
8. Update sliding buffer (append or shift-and-append)
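Step 8's shift-and-append on a full buffer is done with torch.roll (see Notes). A minimal sketch of one update:

```python
import torch

buffer = torch.tensor([[10, 11, 12, 13]])  # full buffer, oldest token first
new_token = torch.tensor([[14]])

# Shift every token one slot left (dropping the oldest) and
# write the newly sampled token into the freed last slot.
buffer = torch.roll(buffer, shifts=-1, dims=1)
buffer[:, -1:] = new_token

assert buffer.tolist() == [[11, 12, 13, 14]]
```

This keeps the buffer a fixed size, so per-step memory stays constant regardless of how many tokens have been generated.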
Phase 4: Decoding and Averaging
# Concatenate historical + generated tokens
full_pre = torch.cat([x_token[0], generated_pre], dim=1)
full_post = torch.cat([x_token[1], generated_post], dim=1)
# Decode back to continuous values (using last max_context tokens)
z = tokenizer.decode([full_pre[:, -max_context:], full_post[:, -max_context:]], half=True)
# Reshape and average across samples
z = z.reshape(-1, sample_count, seq_len, features)
preds = np.mean(z.cpu().numpy(), axis=1)
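The regroup-and-average step can be verified with random data (shapes are illustrative). Because the samples were laid out contiguously per batch item in Phase 1, reshaping recovers them cleanly before the mean:

```python
import numpy as np

batch, sample_count, seq_len, features = 2, 5, 8, 6

# Stand-in for decoded continuous values, flattened over (batch * sample_count).
z = np.random.randn(batch * sample_count, seq_len, features)

# Regroup the parallel samples per batch item, then average them away.
z_grouped = z.reshape(batch, sample_count, seq_len, features)
preds = z_grouped.mean(axis=1)

assert preds.shape == (batch, seq_len, features)
# Batch item 0's prediction is the mean of its sample_count copies.
assert np.allclose(preds[0], z[:sample_count].mean(axis=0))
```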
Source Code Reference
File: model/kronos.py, lines 389-469.
def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context,
pred_len, clip=5, T=1.0, top_k=0, top_p=0.99,
sample_count=5, verbose=False):
with torch.no_grad():
x = torch.clip(x, -clip, clip)
device = x.device
x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(...)
x_token = tokenizer.encode(x, half=True)
# ... sliding buffer generation loop ...
z = tokenizer.decode(input_tokens, half=True)
z = z.reshape(-1, sample_count, z.size(1), z.size(2))
preds = np.mean(z.cpu().numpy(), axis=1)
return preds
Notes
- The entire function runs inside
torch.no_grad()for inference efficiency. - The sliding buffer uses
torch.roll()to shift tokens left when the buffer is full, maintaining O(1) memory per step. - This function is the core generation engine called by both
predict()(single series) andpredict_batch()(multiple series). It does not perform normalization or denormalization; those are handled by the caller. - The output includes both historical and predicted tokens decoded together, so the caller must slice
[:, -pred_len:, :]to isolate predictions. - The
sample_countreplication happens along the batch dimension, so GPU memory usage scales linearly withsample_count.