Implementation:Shiyu coder Kronos QlibDataset Usage

Field Value
implementation_name QlibDataset_Usage
type API Doc
repository https://github.com/shiyu-coder/Kronos
source_file finetune/dataset.py:L9-130
implements Principle:Shiyu_coder_Kronos_Qlib_Training_Dataset
last_updated 2026-02-09 14:00 GMT

Summary

The QlibDataset class is a PyTorch Dataset that pre-computes all valid sliding window indices from pickled financial data, randomly samples from them during each epoch, and returns instance-normalized feature tensors.
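
At a glance, the class reduces to a pre-computing constructor plus three methods. A structural sketch follows (bodies elided, each detailed in the sections below); this is an outline inferred from this page, not the verbatim source:

import torch
from torch.utils.data import Dataset

from config import Config

class QlibDataset(Dataset):
    def __init__(self, data_type: str = 'train'):
        ...  # load pickle, pre-compute all valid (symbol, start_index) pairs

    def set_epoch_seed(self, epoch: int) -> None:
        ...  # re-seed the sampler RNG with Config.seed + epoch

    def __len__(self) -> int:
        ...  # samples per epoch, not the number of distinct windows

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        ...  # idx ignored; returns a random normalized (x, x_stamp) window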

Class

QlibDataset(Dataset)

API Signature

QlibDataset(data_type: str = 'train') -> QlibDataset

Import

from dataset import QlibDataset

Dependencies

  • pickle
  • random
  • numpy
  • torch
  • torch.utils.data.Dataset
  • config.Config

Input

  • data_type (str): Either 'train' or 'val'. Determines which pickle file to load and the number of samples per epoch.
  • Reads Config internally for paths, window parameters, feature lists, and the clip value.
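
The exact Config definition lives in finetune/config.py; below is a hypothetical sketch of the fields this page implies it exposes. Only the quoted values are grounded in this page; the attribute names themselves are assumptions:

class Config:
    dataset_path = './data'   # assumed: directory holding the split pickles
    lookback_window = ...     # with predict_window: window = lookback + predict + 1 = 101
    predict_window = ...
    feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt']
    time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
    clip = 5.0                # normalized values clipped to [-5.0, 5.0]
    seed = 0                  # assumed base seed for the sampler RNG
    n_train_iter = 100_000    # samples per training epoch
    n_val_iter = 20_000       # samples per validation epoch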

Output

__getitem__ returns a tuple (x_tensor, x_stamp_tensor):

Tensor Shape Description
x_tensor (window, n_features) Normalized feature window (float32)
x_stamp_tensor (window, n_time_features) Time feature window (float32)

Where:

  • window = lookback_window + predict_window + 1 = 101 (with defaults)
  • n_features = 6 (open, high, low, close, vol, amt)
  • n_time_features = 5 (minute, hour, weekday, day, month)

Constructor Details

def __init__(self, data_type: str = 'train')
  • Creates a Config instance internally
  • Validates data_type is 'train' or 'val' (raises ValueError otherwise)
  • Loads the corresponding pickle file (train_data.pkl or val_data.pkl)
  • Sets n_samples from Config.n_train_iter (100,000) or Config.n_val_iter (20,000)
  • Initializes a dedicated random.Random instance with Config.seed
  • Computes window size: lookback_window + predict_window + 1
  • Pre-computes all valid (symbol, start_index) pairs
  • Generates time features (minute, hour, weekday, day, month) from the datetime column
  • Caps n_samples at the total number of available sample indices
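
A hedged reconstruction of these steps follows; variable names and the pickle layout (assumed to be a {symbol: DataFrame} mapping with a datetime column) are guesses, not verbatim source:

def __init__(self, data_type: str = 'train'):
    self.config = Config()
    if data_type not in ('train', 'val'):
        raise ValueError(f"data_type must be 'train' or 'val', got {data_type!r}")

    # Load the split's pickle (train_data.pkl or val_data.pkl).
    path = f"{self.config.dataset_path}/{data_type}_data.pkl"
    with open(path, 'rb') as f:
        self.data = pickle.load(f)

    self.n_samples = (self.config.n_train_iter if data_type == 'train'
                      else self.config.n_val_iter)
    self.py_rng = random.Random(self.config.seed)  # dedicated sampler RNG
    self.window = self.config.lookback_window + self.config.predict_window + 1

    self.indices = []
    for symbol, df in self.data.items():
        # Time features derived from the datetime column (column name assumed).
        for name in ('minute', 'hour', 'weekday', 'day', 'month'):
            df[name] = getattr(df['datetime'].dt, name)
        # Every start index whose full window fits inside this symbol's series.
        for start in range(len(df) - self.window + 1):
            self.indices.append((symbol, start))

    # Cap at the number of distinct windows actually available.
    self.n_samples = min(self.n_samples, len(self.indices))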

Key Methods

set_epoch_seed(epoch)

def set_epoch_seed(self, epoch: int) -> None

Re-seeds the random sampler with Config.seed + epoch. This matters for reproducibility in distributed training (DDP): each epoch draws a different but deterministic stream of samples.
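
The description implies a one-line body; a minimal sketch:

def set_epoch_seed(self, epoch: int) -> None:
    # New deterministic stream per epoch; same seed on every DDP rank.
    self.py_rng = random.Random(self.config.seed + epoch)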

__len__()

def __len__(self) -> int

Returns self.n_samples, the number of samples per epoch.

__getitem__(idx)

def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]

Note: The idx argument is ignored. Each call instead draws a random index from the pre-computed self.indices list using self.py_rng, so samples are drawn uniformly at random from the full set of valid windows rather than in dataset order.

Steps:

  1. Draw a random (symbol, start_idx) pair
  2. Extract the sliding window of size window from the symbol's data
  3. Separate main features and time features
  4. Apply instance-level normalization: (x - mean) / (std + 1e-5)
  5. Clip normalized values to [-clip, clip] (default [-5.0, 5.0])
  6. Return as PyTorch float32 tensors
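
Put together, a hedged sketch of the method (feature column names are taken from the Output section above; everything else is an assumption):

def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
    symbol, start = self.py_rng.choice(self.indices)  # idx is ignored
    window_df = self.data[symbol].iloc[start:start + self.window]

    x = window_df[['open', 'high', 'low', 'close', 'vol', 'amt']].to_numpy(np.float32)
    x_stamp = window_df[['minute', 'hour', 'weekday', 'day', 'month']].to_numpy(np.float32)

    # Instance-level normalization over the time axis, then outlier clipping.
    x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
    x = (x - x_mean) / (x_std + 1e-5)
    x = np.clip(x, -self.config.clip, self.config.clip)

    return torch.from_numpy(x), torch.from_numpy(x_stamp)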

Normalization

Instance-level normalization is computed per window:

x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.config.clip, self.config.clip)
  • Mean and standard deviation are computed per feature (axis=0) across the time dimension
  • A small epsilon (1e-5) prevents division by zero
  • Clipping to [-5.0, 5.0] suppresses extreme outliers
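
The arithmetic is easy to check in isolation; this standalone snippet applies the same three lines to synthetic data with the default clip of 5.0:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=100.0, scale=3.0, size=(101, 6)).astype(np.float32)  # fake window

x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = np.clip((x - x_mean) / (x_std + 1e-5), -5.0, 5.0)

print(x.mean(axis=0))    # ~0 per feature
print(x.std(axis=0))     # ~1 per feature
print(x.min(), x.max())  # guaranteed within [-5, 5]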

Example Usage

from dataset import QlibDataset

train_dataset = QlibDataset(data_type='train')
print(f"Dataset length: {len(train_dataset)}")

# Get a sample (idx is ignored, random sampling used)
x, x_stamp = train_dataset[0]
print(f"Feature shape: {x.shape}")       # torch.Size([101, 6])
print(f"Time feature shape: {x_stamp.shape}")  # torch.Size([101, 5])

# Set epoch seed for DDP reproducibility
train_dataset.set_epoch_seed(epoch=1)
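
In training, the dataset drops into a standard DataLoader. A sketch with arbitrary batch size and worker count, re-seeding before each epoch so freshly forked workers inherit the new stream:

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=64, num_workers=4)

for epoch in range(3):
    train_dataset.set_epoch_seed(epoch)  # before iteration, so workers see it
    for x, x_stamp in loader:
        ...  # x: (batch, 101, 6), x_stamp: (batch, 101, 5)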

Source Reference

File: finetune/dataset.py, lines 9-130.
