Heuristic:Shiyu_coder_Kronos_Sampling_Temperature_Tuning
| Knowledge Sources | |
|---|---|
| Domains | Inference, LLMs, Time_Series |
| Last Updated | 2026-02-09 13:47 GMT |
## Overview
Guidelines for tuning the `temperature`, `top_k`, `top_p`, and `sample_count` parameters to control stochasticity and stability in Kronos autoregressive forecasting.
## Description
Kronos uses stochastic sampling from discrete token distributions during autoregressive inference. The sampling behavior is controlled by four parameters: temperature (controls randomness), top_k (limits to k highest probability tokens), top_p (nucleus sampling threshold), and sample_count (number of parallel sample paths averaged). The codebase uses two distinct parameter profiles: a general inference profile (T=1.0, top_p=0.9, sample_count=1) and a backtesting profile (T=0.6, top_p=0.9, sample_count=5) that trades diversity for stability.
## Usage
Use this heuristic when:
- Exploratory prediction: Use default T=1.0 for diverse forecast scenarios
- Trading backtests: Lower temperature (T=0.6) and increase sample_count (5) for more stable, averaged predictions
- Debugging: Set `sample_logits=False` for greedy (deterministic) decoding to isolate model behavior from sampling noise
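The two profiles can be captured as plain keyword-argument dicts and forwarded to the predictor. A minimal sketch; the parameter names and values come from the repo, but the commented call (`predictor`, `df`, `x_ts`, `y_ts`, `pred_len=120`) is illustrative only:

```python
# Sampling profiles from the repo's defaults and finetune config.
GENERAL_PROFILE = dict(T=1.0, top_k=0, top_p=0.9, sample_count=1)
BACKTEST_PROFILE = dict(T=0.6, top_k=0, top_p=0.9, sample_count=5)

# Hypothetical usage (predictor, df, x_ts, y_ts assumed to exist):
# preds = predictor.predict(df, x_ts, y_ts, pred_len=120, **BACKTEST_PROFILE)
```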
## The Insight (Rule of Thumb)
- Action: Adjust temperature and sample_count based on use case.
- General inference defaults: `T=1.0`, `top_p=0.9`, `top_k=0` (disabled), `sample_count=1`
- Backtesting defaults: `T=0.6`, `top_p=0.9`, `top_k=0`, `sample_count=5`
- Trade-off: Lower temperature produces more deterministic but potentially overconfident predictions. Higher sample_count improves stability but increases compute linearly (each sample runs the full autoregressive loop).
- Key relationship: `sample_count` paths are generated in parallel and then averaged in output space (not token space) to produce the final prediction.
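The sharpening effect of a lower temperature can be seen on a toy logit vector. This is a standalone sketch: the logit values are made up, and only the two temperature settings match the profiles above.

```python
import numpy as np

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a 1-D logit vector."""
    z = np.asarray(logits, dtype=float) / T
    z -= z.max()                  # subtract max for numerical stability
    p = np.exp(z)
    return p / p.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])   # made-up token logits
p_hot = softmax(logits, T=1.0)   # general profile: learned distribution
p_cool = softmax(logits, T=0.6)  # backtesting profile: sharper

# Lowering T concentrates probability mass on the highest-logit token.
print(p_hot.max() < p_cool.max())   # True
```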
## Reasoning
Financial forecasting requires balancing exploration (capturing multiple scenarios) with stability (reliable signals for trading). The two-profile approach reflects this:
- **Temperature = 1.0 (general):** Standard softmax sampling that preserves the model's learned distribution. Suitable for visualization and for understanding forecast uncertainty.
- **Temperature = 0.6 (backtesting):** Sharpens the distribution, making the model more likely to predict high-probability tokens. Combined with averaging 5 sample paths, this produces smoother, more reliable trading signals. The 0.6 value was likely tuned empirically to balance signal quality against overconfidence.
- **Nucleus sampling (top_p = 0.9):** Keeps the smallest set of highest-probability tokens whose cumulative probability reaches 0.9 and discards the remaining low-probability tail, preventing rare-token artifacts while preserving meaningful diversity.
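A simplified 1-D version of nucleus filtering makes the cutoff concrete. This is a sketch over a probability vector, not the repo's `top_k_top_p_filtering` (which operates on batched logits):

```python
import numpy as np

def top_p_filter(probs, top_p=0.9):
    """Keep the smallest set of highest-probability tokens whose
    cumulative mass reaches top_p; zero the tail and renormalize."""
    order = np.argsort(probs)[::-1]       # indices, descending probability
    cum = np.cumsum(probs[order])
    cutoff = np.searchsorted(cum, top_p) + 1   # include the crossing token
    keep = order[:cutoff]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()

probs = np.array([0.5, 0.3, 0.15, 0.04, 0.01])
filtered = top_p_filter(probs, 0.9)
# The two rarest tokens fall outside the 0.9 nucleus and are zeroed;
# the surviving three are renormalized.
```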
Evidence from default inference parameters in `model/kronos.py:519`:
```python
def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
```
Backtesting inference parameters from `finetune/config.py:115-118`:
```python
self.inference_T = 0.6
self.inference_top_p = 0.9
self.inference_top_k = 0
self.inference_sample_count = 5
```
Sampling implementation from `model/kronos.py:373-386`:
```python
def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
    logits = logits / temperature
    if top_k is not None or top_p is not None:
        if top_k > 0 or top_p < 1.0:
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=-1)  # F is torch.nn.functional
    if not sample_logits:
        # Greedy decoding: take the single highest-probability token.
        _, x = torch.topk(probs, k=1, dim=-1)
    else:
        # Stochastic decoding: draw one token from the distribution.
        x = torch.multinomial(probs, num_samples=1)
    return x
```
Sample averaging in output space from `model/kronos.py:465-467`:
```python
z = z.reshape(-1, sample_count, z.size(1), z.size(2))
preds = z.cpu().numpy()
preds = np.mean(preds, axis=1)
```
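The reshape-then-average step can be reproduced on toy data. The array sizes below are made up; only the pattern (sample paths contiguous along the batch axis, averaged per timestep and feature in output space) matches the repo:

```python
import numpy as np

batch, sample_count, pred_len, n_features = 2, 5, 4, 3

# Decoded outputs: each input's sample paths sit contiguously along
# the leading axis, mirroring how inputs are repeated before decoding.
rng = np.random.default_rng(0)
z = rng.standard_normal((batch * sample_count, pred_len, n_features))

# Average in output space (per timestep/feature), not in token space.
preds = z.reshape(batch, sample_count, pred_len, n_features).mean(axis=1)

assert preds.shape == (batch, pred_len, n_features)
```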
## Related Pages
- Implementation:Shiyu_coder_Kronos_Auto_Regressive_Inference
- Implementation:Shiyu_coder_Kronos_KronosPredictor_Predict
- Implementation:Shiyu_coder_Kronos_KronosPredictor_Predict_Batch
- Implementation:Shiyu_coder_Kronos_Generate_Predictions_Qlib
- Implementation:Shiyu_coder_Kronos_WebUI_App
- Principle:Shiyu_coder_Kronos_Autoregressive_Token_Generation
- Principle:Shiyu_coder_Kronos_Web_UI_Prediction_Service