Implementation:Ray project Ray PyTorch Hyperparameter Tuning Tutorial
| Knowledge Sources | Ray |
|---|---|
| Domains | Hyperparameter_Tuning, Machine_Learning, Documentation |
| Last Updated | 2026-02-13 |
Overview
PyTorch Hyperparameter Tuning Tutorial is a self-contained Python tutorial script that demonstrates how to integrate Ray Tune with PyTorch for distributed hyperparameter tuning of a CIFAR-10 image classifier.
Description
This tutorial script (pytorch_tutorials_hyperparameter_tuning_tutorial.py) defines a convolutional neural network (Net) with configurable fully connected layer sizes, wraps training and data loading into Ray Tune-compatible functions, and configures a search space for learning rate, batch size, and layer sizes. It uses the ASHAScheduler for early stopping of poorly performing trials, supports multi-GPU training via nn.DataParallel, and implements checkpoint saving and restoration for fault tolerance. The script is mirrored from the official PyTorch tutorials and is kept in sync via hash checks in test_hashes.py.
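The checkpoint save and restore pattern can be sketched without Ray or PyTorch. In the script, Ray's Checkpoint wraps a directory of saved files. This minimal sketch makes some assumptions of its own:

- the directory holds one pickled dict;
- the file name data.pkl, the key names, and the helpers save_checkpoint/load_checkpoint are illustrative;
- plain dicts stand in for net.state_dict() and optimizer.state_dict().

```python
import os
import pickle
import tempfile


def save_checkpoint(checkpoint_dir, epoch, net_state, optimizer_state):
    """Pickle the epoch and both state dicts into <checkpoint_dir>/data.pkl (hypothetical layout)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    data = {
        "epoch": epoch,
        "net_state_dict": net_state,
        "optimizer_state_dict": optimizer_state,
    }
    with open(os.path.join(checkpoint_dir, "data.pkl"), "wb") as fp:
        pickle.dump(data, fp)


def load_checkpoint(checkpoint_dir):
    """Return (start_epoch, net_state, optimizer_state); a fresh start if nothing was saved."""
    if checkpoint_dir is None:
        return 0, None, None
    path = os.path.join(checkpoint_dir, "data.pkl")
    if not os.path.exists(path):
        return 0, None, None
    with open(path, "rb") as fp:
        data = pickle.load(fp)
    # Resume with the epoch after the last completed one.
    return data["epoch"] + 1, data["net_state_dict"], data["optimizer_state_dict"]


# Demo: save after epoch 3, then restore as a restarted trial would.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "checkpoint")
    save_checkpoint(ckpt_dir, 3, {"fc1.weight": [0.1, -0.2]}, {"lr": 0.01})
    start_epoch, net_state, opt_state = load_checkpoint(ckpt_dir)
    print(start_epoch, opt_state)  # 4 {'lr': 0.01}
```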
Usage
Use this file as a reference tutorial for integrating Ray Tune into PyTorch training workflows. It is primarily used as documentation content that appears on the Ray docs site and is validated for consistency with the upstream PyTorch tutorial. Modify this file if the upstream PyTorch tutorial changes or if Ray Tune's API changes require updates.
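Consistency with upstream can be pictured as comparing a content hash of the mirrored file against a pinned value. This is a minimal sketch of one way such a check could work. It is not the contents of test_hashes.py. The choice of SHA-256, the helper names file_sha256/is_in_sync, and the pin-then-compare workflow are all assumptions.

```python
import hashlib
import os
import tempfile


def file_sha256(path, chunk_size=65536):
    """Return the hex SHA-256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fp:
        for chunk in iter(lambda: fp.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def is_in_sync(path, expected_hash):
    """True if the mirrored file still matches the pinned hash."""
    return file_sha256(path) == expected_hash


# Demo with a throwaway file standing in for the mirrored tutorial.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tutorial.py")
    with open(path, "wb") as fp:
        fp.write(b"print('hello')\n")
    pinned = file_sha256(path)           # pin the current hash
    before_edit = is_in_sync(path, pinned)
    with open(path, "ab") as fp:
        fp.write(b"# local edit\n")      # any change breaks the match
    after_edit = is_in_sync(path, pinned)
    print(before_edit, after_edit)       # True False
```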
Code Reference
Source Location
doc/external/pytorch_tutorials_hyperparameter_tuning_tutorial.py
Signature
def load_data(data_dir="./data"):
    """Load CIFAR-10 training and test datasets."""
    ...

class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        """CNN with configurable fully connected layer sizes."""
        ...

def train_cifar(config, data_dir=None):
    """Train a CIFAR-10 classifier with Ray Tune-compatible config."""
    ...

def test_accuracy(net, device="cpu"):
    """Evaluate model accuracy on the CIFAR-10 test set."""
    ...

def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    """Run the full hyperparameter tuning experiment."""
    ...
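A Tune-compatible training function reads its hyperparameters from config, trains for a number of epochs, and reports a metrics dict containing loss and accuracy after each epoch. The dependency-free toy below shows that contract. The function toy_train and its fake loss curve are hypothetical, and a plain report callback stands in for Ray's reporting call.

```python
def toy_train(config, report, num_epochs=10):
    """Same loop shape as train_cifar: read config, train, report metrics every epoch."""
    # The real script also reads config["l1"], config["l2"] and config["batch_size"]
    # to build Net and the data loaders; this toy only needs the learning rate.
    lr = config["lr"]
    loss = 2.3  # close to ln(10), the loss of a uniform guess over 10 classes
    for _ in range(num_epochs):
        loss *= 1.0 - min(lr * 5, 0.5)  # fake curve: higher lr shrinks loss faster
        report({"loss": loss, "accuracy": 1.0 - loss / 2.3})


history = []  # collects what would be sent to Ray Tune
toy_train({"l1": 64, "l2": 32, "lr": 0.01, "batch_size": 8}, history.append, num_epochs=3)
print(len(history), sorted(history[0]))  # 3 ['accuracy', 'loss']
```

The scheduler and the best-trial selection both rely on the metric names used here, so the keys must match the metric passed to them ("loss" in this tutorial).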
Import
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
I/O Contract
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| config["l1"] | int | (search space; Net default 120) | Size of the first fully connected layer |
| config["l2"] | int | (search space; Net default 84) | Size of the second fully connected layer |
| config["lr"] | float | (search space) | Learning rate for SGD optimizer |
| config["batch_size"] | int | (search space) | Batch size for data loaders |
| data_dir | str | "./data" (load_data default) | Directory where CIFAR-10 is downloaded and stored |
| num_samples | int | 10 | Number of trials, i.e. hyperparameter configurations sampled from the search space |
| max_num_epochs | int | 10 | Maximum number of training epochs per trial |
| gpus_per_trial | int | 2 | Number of GPUs allocated per trial |
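gpus_per_trial and the 2 CPUs that each trial requests in the tune.run example on this page together bound how many trials can run at once. This back-of-the-envelope calculation makes a few simplifying assumptions:

- a single node;
- whole-number resource requests;
- the helper name max_concurrent_trials is hypothetical.

Ray's scheduler makes the actual placement decisions.

```python
def max_concurrent_trials(node_cpus, node_gpus, cpus_per_trial=2, gpus_per_trial=0):
    """Upper bound on trials that fit on one node at the same time."""
    limit = node_cpus // cpus_per_trial
    if gpus_per_trial > 0:
        # GPU trials are also capped by how many full GPU requests the node can satisfy.
        limit = min(limit, node_gpus // gpus_per_trial)
    return limit


print(max_concurrent_trials(16, 4, gpus_per_trial=2))  # 2
print(max_concurrent_trials(16, 0, gpus_per_trial=0))  # 8
```

On a GPU node, raising gpus_per_trial therefore trades parallel trials for more GPUs per trial (which the script uses through nn.DataParallel).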
Outputs
| Output | Type | Description |
|---|---|---|
| Validation loss | float | Per-trial validation loss reported to Ray Tune |
| Validation accuracy | float | Per-trial validation accuracy reported to Ray Tune |
| Checkpoint | ray.train.Checkpoint | Saved model state dict and optimizer state per epoch |
| Best trial config | dict | Configuration of the best performing trial |
| Test set accuracy | float | Accuracy of the best model on the held-out test set |
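The best trial is chosen by comparing trials on a reported metric, for example the last reported loss in "min" mode. The plain-Python sketch below illustrates that selection. The trial records, their made-up metric values, and the helper best_trial_by_last are hypothetical stand-ins for Ray Tune's result objects, not its API.

```python
def best_trial_by_last(trials, metric="loss", mode="min"):
    """Return the trial whose last reported `metric` is best."""
    reported = [t for t in trials if t["results"]]  # ignore trials with no reports
    pick = min if mode == "min" else max
    return pick(reported, key=lambda t: t["results"][-1][metric])


# Made-up metrics for two trials; the second was stopped after one epoch.
trials = [
    {"config": {"l1": 64, "l2": 16, "lr": 0.003, "batch_size": 4},
     "results": [{"loss": 1.9, "accuracy": 0.31}, {"loss": 1.5, "accuracy": 0.45}]},
    {"config": {"l1": 8, "l2": 4, "lr": 0.05, "batch_size": 2},
     "results": [{"loss": 2.1, "accuracy": 0.22}]},
]
best = best_trial_by_last(trials)
print(best["config"])  # {'l1': 64, 'l2': 16, 'lr': 0.003, 'batch_size': 4}
```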
Usage Examples
Running the tutorial as a standalone script:
# Run 10 trials of up to 10 epochs each, CPU only
# (gpus_per_trial=0 overrides main()'s default of 2)
if __name__ == "__main__":
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
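A thin command-line wrapper could forward flags to main, so these arguments can be changed without editing the script. This wrapper is a hypothetical addition, not part of the tutorial. The flag names mirror main's parameters, and the defaults are the values used in the call above.

```python
import argparse


def parse_args(argv=None):
    """Parse CLI flags that mirror main()'s keyword arguments (hypothetical)."""
    parser = argparse.ArgumentParser(description="Ray Tune + PyTorch CIFAR-10 tuning")
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--max-num-epochs", type=int, default=10)
    parser.add_argument("--gpus-per-trial", type=int, default=0)
    return parser.parse_args(argv)


args = parse_args(["--num-samples", "20", "--gpus-per-trial", "1"])
print(vars(args))  # {'num_samples': 20, 'max_num_epochs': 10, 'gpus_per_trial': 1}
```

argparse turns dashes into underscores, so main(**vars(parse_args())) would replace the hard-coded call.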
Configuring the search space for Ray Tune:
from ray import tune

config = {
    "l1": tune.choice([2**i for i in range(9)]),
    "l2": tune.choice([2**i for i in range(9)]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}
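The standard library can approximate what this search space draws:

- choice picks uniformly from a list;
- loguniform samples uniformly in log space, so each decade between 1e-4 and 1e-1 is equally likely.

This is a minimal sketch for intuition; the tutorial itself relies on Ray Tune's samplers, and the function sample_config is hypothetical.

```python
import math
import random


def sample_config(rng):
    """Draw one configuration from a space shaped like the tutorial's."""
    sizes = [2**i for i in range(9)]  # 1, 2, 4, ..., 256
    log_lo, log_hi = math.log(1e-4), math.log(1e-1)
    return {
        "l1": rng.choice(sizes),
        "l2": rng.choice(sizes),
        "lr": math.exp(rng.uniform(log_lo, log_hi)),  # log-uniform draw
        "batch_size": rng.choice([2, 4, 8, 16]),
    }


rng = random.Random(0)  # seeded for reproducibility
samples = [sample_config(rng) for _ in range(10)]  # one config per trial (num_samples=10)
```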
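Using the ASHAScheduler for early stopping and launching the experiment with tune.run (this code runs inside main(), where config, data_dir, num_samples, max_num_epochs, and gpus_per_trial are already defined):
from functools import partial

from ray import tune
from ray.tune.schedulers import ASHAScheduler

scheduler = ASHAScheduler(
    metric="loss",         # key reported by train_cifar each epoch
    mode="min",            # lower validation loss is better
    max_t=max_num_epochs,  # upper limit on epochs for any trial
    grace_period=1,        # every trial trains for at least one epoch
    reduction_factor=2,    # keep roughly the better half at each milestone
)
result = tune.run(
    partial(train_cifar, data_dir=data_dir),  # bind data_dir; Tune supplies config
    resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
    config=config,
    num_samples=num_samples,
    scheduler=scheduler,
)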
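The tutorial's settings are grace_period=1, reduction_factor=2, and max_t=10. With these, ASHA compares trials at epochs 1, 2, 4, and 8, and a trial continues past a milestone only if its loss is roughly in the better half of the results recorded there. The sketch below is a simplified, dependency-free version of that asynchronous successive-halving rule, assuming a single bracket. It is not Ray's code, and the helper names asha_milestones, percentile, and should_continue are hypothetical.

```python
import math


def asha_milestones(max_t, grace_period=1, reduction_factor=2):
    """Epochs at which trials are compared: grace_period * reduction_factor**k below max_t."""
    milestones = []
    t = grace_period
    while t < max_t:  # a trial that reaches max_t is stopped anyway
        milestones.append(t)
        t *= reduction_factor
    return milestones


def percentile(values, q):
    """Linear-interpolation percentile, q in [0, 100]."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def should_continue(recorded, value, reduction_factor=2, mode="min"):
    """Decide whether a trial reaching a milestone keeps running, then record its value."""
    sign = -1 if mode == "min" else 1  # compare in "higher is better" space
    rewards = [sign * v for v in recorded]
    keep_going = True
    if rewards:  # the first trial at a milestone always continues
        cutoff = percentile(rewards, (1 - 1 / reduction_factor) * 100)
        keep_going = sign * value >= cutoff
    recorded.append(value)
    return keep_going


print(asha_milestones(10))  # [1, 2, 4, 8]
rung = []  # losses recorded at a single milestone, in arrival order
decisions = [should_continue(rung, loss) for loss in [1.0, 2.0, 1.4, 1.6]]
print(decisions)  # [True, False, True, False]
```

Each decision uses only the results that have already arrived, so no trial waits for others to finish. This is what makes the scheme asynchronous.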
Related Pages
- Ray_project_Ray_Sphinx_Configuration - Sphinx config that builds the Ray documentation site
- Ray_project_Ray_Custom_Sphinx_Directives - Custom directives used in Ray documentation
- Ray - Ray project repository