Implementation:Shiyu coder Kronos QlibDataset Usage
| Field | Value |
|---|---|
| implementation_name | QlibDataset_Usage |
| type | API Doc |
| repository | https://github.com/shiyu-coder/Kronos |
| source_file | finetune/dataset.py:L9-130 |
| implements | Principle:Shiyu_coder_Kronos_Qlib_Training_Dataset |
| last_updated | 2026-02-09 14:00 GMT |
Summary
The QlibDataset class is a PyTorch Dataset that pre-computes all valid sliding window indices from pickled financial data, randomly samples from them during each epoch, and returns instance-normalized feature tensors.
Class
QlibDataset(Dataset)
API Signature
QlibDataset(data_type: str = 'train') -> QlibDataset
Import
from dataset import QlibDataset
Dependencies
`pickle`, `random`, `numpy`, `torch`, `torch.utils.data.Dataset`, `config.Config`
Input
- `data_type` (str): Either `'train'` or `'val'`. Determines which pickle file to load and the number of samples per epoch.
- Reads `Config` internally for paths, window parameters, feature lists, and clip values.
Output
__getitem__ returns a tuple (x_tensor, x_stamp_tensor):
| Tensor | Shape | Description |
|---|---|---|
| `x_tensor` | `(window, n_features)` | Normalized feature window (float32) |
| `x_stamp_tensor` | `(window, n_time_features)` | Time feature window (float32) |
Where:
- `window = lookback_window + predict_window + 1 = 101` (with defaults)
- `n_features = 6` (open, high, low, close, vol, amt)
- `n_time_features = 5` (minute, hour, weekday, day, month)
Constructor Details
def __init__(self, data_type: str = 'train'):
- Creates a `Config` instance internally
- Validates `data_type` is `'train'` or `'val'` (raises `ValueError` otherwise)
- Loads the corresponding pickle file (`train_data.pkl` or `val_data.pkl`)
- Sets `n_samples` from `Config.n_train_iter` (100,000) or `Config.n_val_iter` (20,000)
- Initializes a dedicated `random.Random` instance with `Config.seed`
- Computes window size: `lookback_window + predict_window + 1`
- Pre-computes all valid `(symbol, start_index)` pairs
- Generates time features (minute, hour, weekday, day, month) from the datetime column
- Caps `n_samples` at the total number of available sample indices
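The index pre-computation step can be sketched as follows. This is a minimal illustration of the idea, not the exact source: the data container is assumed to be a mapping from symbol to a per-symbol series, and only the lengths matter here.

```python
def precompute_indices(data, window):
    """Return every (symbol, start_index) pair whose full window fits
    inside that symbol's series. Illustrative sketch, assuming `data`
    maps symbol -> sequence; the real class stores this as an attribute."""
    indices = []
    for symbol, series in data.items():
        # a window starting at i covers rows i .. i + window - 1,
        # so valid starts are 0 .. len(series) - window
        n_starts = len(series) - window + 1
        indices.extend((symbol, i) for i in range(n_starts))
    return indices
```

Sampling from this flat list (rather than iterating symbols) weights each symbol by how many valid windows it contributes.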
Key Methods
set_epoch_seed(epoch)
def set_epoch_seed(self, epoch: int) -> None
Sets a new seed for the random sampler: Config.seed + epoch. This is crucial for reproducibility in distributed training (DDP), ensuring each epoch sees a different but deterministic ordering.
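The seeding scheme can be illustrated with a minimal stand-in class (the attribute names `base_seed` and `py_rng` are assumptions mirroring the description, not the exact source): reseeding with `base_seed + epoch` makes each epoch's sampling sequence deterministic yet distinct.

```python
import random

class EpochSeededSampler:
    """Minimal sketch of the epoch-seeding pattern described above."""
    def __init__(self, base_seed=42):
        self.base_seed = base_seed
        self.py_rng = random.Random(base_seed)

    def set_epoch_seed(self, epoch):
        # same epoch -> identical draw sequence on every rank/restart;
        # different epoch -> a different deterministic sequence
        self.py_rng.seed(self.base_seed + epoch)
```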
__len__()
def __len__(self) -> int
Returns self.n_samples, the number of samples per epoch.
__getitem__(idx)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]
Note: The idx argument is ignored. Instead, a random index is drawn from the pre-computed self.indices list using self.py_rng. This ensures random sampling over the entire dataset for each call.
Steps:
1. Draw a random `(symbol, start_idx)` pair
2. Extract the sliding window of size `window` from the symbol's data
3. Separate main features and time features
4. Apply instance-level normalization: `(x - mean) / (std + 1e-5)`
5. Clip normalized values to `[-clip, clip]` (default `[-5.0, 5.0]`)
6. Return as PyTorch float32 tensors
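The steps above can be sketched in NumPy. The function name, argument layout, and the assumption that main and time features sit side by side in one array are illustrative, not the exact source:

```python
import random
import numpy as np

def get_random_window(data, indices, rng, window, n_features, clip=5.0):
    """Hedged sketch of the __getitem__ pipeline described above."""
    symbol, start = rng.choice(indices)                 # step 1: random pair
    block = data[symbol][start:start + window]          # step 2: slice window
    x = block[:, :n_features].astype(np.float32)        # step 3: main features...
    x_stamp = block[:, n_features:].astype(np.float32)  # ...vs. time features
    mean, std = x.mean(axis=0), x.std(axis=0)           # step 4: per-window stats
    x = (x - mean) / (std + 1e-5)
    x = np.clip(x, -clip, clip)                         # step 5: clip outliers
    return x, x_stamp  # step 6 would wrap these in torch.from_numpy(...)
```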
Normalization
Instance-level normalization is computed per window:
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.config.clip, self.config.clip)
- Mean and standard deviation are computed per feature (axis=0) across the time dimension
- A small epsilon (1e-5) prevents division by zero
- Clipping to `[-5.0, 5.0]` suppresses extreme outliers
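A standalone demonstration of the epsilon's role (illustrative data, not from the source): a feature that is constant within a window has zero standard deviation, and without the `1e-5` term the division would produce NaN or inf.

```python
import numpy as np

# toy window: first column is constant, second varies
x = np.array([[1.0, 100.0],
              [1.0, 200.0],
              [1.0, 300.0]])
mean, std = x.mean(axis=0), x.std(axis=0)  # std of column 0 is exactly 0
normed = (x - mean) / (std + 1e-5)         # epsilon keeps the result finite
# the constant column normalizes to ~0 instead of NaN
```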
Example Usage
from dataset import QlibDataset
train_dataset = QlibDataset(data_type='train')
print(f"Dataset length: {len(train_dataset)}")
# Get a sample (idx is ignored, random sampling used)
x, x_stamp = train_dataset[0]
print(f"Feature shape: {x.shape}") # torch.Size([101, 6])
print(f"Time feature shape: {x_stamp.shape}") # torch.Size([101, 5])
# Set epoch seed for DDP reproducibility
train_dataset.set_epoch_seed(epoch=1)
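For a full training loop, the per-epoch reseed typically happens just before iterating the `DataLoader`. The sketch below uses a toy stand-in with the same interface so it runs on its own; substitute the real `QlibDataset`, and note that the batch size and epoch count are illustrative:

```python
import random
import torch
from torch.utils.data import Dataset, DataLoader

class _ToyDataset(Dataset):
    """Stand-in mimicking QlibDataset's interface (illustrative only)."""
    def __init__(self, base_seed=42, n_samples=8):
        self.base_seed, self.n_samples = base_seed, n_samples
        self.py_rng = random.Random(base_seed)

    def set_epoch_seed(self, epoch):
        self.py_rng.seed(self.base_seed + epoch)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):  # idx ignored, as in QlibDataset
        fill = float(self.py_rng.randint(0, 9))
        return torch.full((101, 6), fill), torch.zeros(101, 5)

dataset = _ToyDataset()
loader = DataLoader(dataset, batch_size=4)  # num_workers=0 keeps RNG in-process

for epoch in range(2):
    dataset.set_epoch_seed(epoch)  # reseed before each epoch (DDP pattern)
    for x, x_stamp in loader:
        pass  # forward/backward pass would go here
```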
Source Reference
File: finetune/dataset.py, lines 9-130.