Implementation: Fastai Fastbook Training Loop Manual
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Optimization, Software Engineering |
| Last Updated | 2026-02-09 17:00 GMT |
Overview
Concrete pattern for constructing a complete training loop from scratch, progressing from a raw manual loop through an optimizer abstraction, as demonstrated in fastbook Chapters 4 and 16.
Description
This implementation shows three stages of the training loop:
- Fully manual: The loop directly computes gradients and updates `params.data -= lr * params.grad`.
- With `BasicOptim`: The update and `zero_grad` logic are encapsulated in a simple optimizer class.
- With PyTorch/fastai `SGD`: The loop uses `torch.optim.SGD` or fastai's `SGD` as a drop-in replacement.
Each stage produces identical results, demonstrating how library abstractions wrap the same fundamental loop.
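The claim that the stages produce identical results can be spot-checked on a toy parameter. The sketch below (not from the book; the scalar `(p ** 2).sum()` loss is an illustrative assumption) applies one manual update and one `torch.optim.SGD` step to identical starting tensors and confirms they agree:

```python
import torch

# Two copies of the same parameter, one per update style
p_manual = torch.tensor([1.0, 2.0], requires_grad=True)
p_opt = torch.tensor([1.0, 2.0], requires_grad=True)
lr = 0.1

# Stage 1: fully manual update
loss = (p_manual ** 2).sum()
loss.backward()
p_manual.data -= lr * p_manual.grad
p_manual.grad.zero_()

# Stage 3: the same update via torch.optim.SGD (momentum-free)
opt = torch.optim.SGD([p_opt], lr=lr)
loss = (p_opt ** 2).sum()
loss.backward()
opt.step()
opt.zero_grad()

assert torch.allclose(p_manual, p_opt)  # identical parameters after one step
```

With gradient `2p = [2.0, 4.0]` and `lr = 0.1`, both paths land on `[0.8, 1.6]`, which is exactly why the library optimizer can replace the manual loop without changing results.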
Usage
Use this pattern when:
- Building a training loop from the ground up for educational purposes.
- Prototyping a custom training procedure before adopting a framework.
- Understanding what `Learner.fit()` does internally.
Code Reference
Source Location
- Repository: fastbook
- File: 04_mnist_basics.ipynb (Chapter 4), "Putting It All Together" and "Creating an Optimizer" sections
- File: 16_accel_sgd.ipynb (Chapter 16), "Optimizer" section
Signature
Stage 1: Fully manual loop
```python
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = loss_func(preds, yb)
    loss.backward()

def train_epoch(model, lr, params):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad * lr
            p.grad.zero_()

def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)
```
Stage 2: BasicOptim class
```python
class BasicOptim:
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr

    def step(self, *args, **kwargs):
        for p in self.params:
            p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params:
            p.grad = None
```
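Because `BasicOptim.step` is exactly `p -= lr * grad`, it should match `torch.optim.SGD` with default (zero) momentum. The following sanity check is not from the book; the random tensor and `(p * p).sum()` loss are illustrative assumptions, and `BasicOptim` is repeated so the snippet runs standalone:

```python
import torch

# BasicOptim as defined above, repeated so this check is self-contained
class BasicOptim:
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr
    def step(self, *args, **kwargs):
        for p in self.params:
            p.data -= p.grad.data * self.lr
    def zero_grad(self, *args, **kwargs):
        for p in self.params:
            p.grad = None

torch.manual_seed(0)
a = torch.randn(3, requires_grad=True)
b = a.detach().clone().requires_grad_()

basic = BasicOptim([a], lr=0.1)
sgd = torch.optim.SGD([b], lr=0.1)

# One gradient step with each optimizer from the same starting point
for opt, p in [(basic, a), (sgd, b)]:
    loss = (p * p).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()

assert torch.allclose(a, b)  # BasicOptim and momentum-free SGD agree
```

Note that `zero_grad` here sets `p.grad = None` rather than zeroing in place; PyTorch re-creates the gradient tensor on the next `backward()`, so the two conventions are interchangeable in this loop.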
Stage 3: Full training function with optimizer
```python
def train_epoch(model):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()

def train_model(model, epochs):
    for i in range(epochs):
        train_epoch(model)
        print(validate_epoch(model), end=' ')
```
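To see the Stage 3 shape run end to end without MNIST, here is a minimal self-contained sketch on synthetic data. The 2-feature dataset, `BCEWithLogitsLoss`, and accuracy threshold are illustrative assumptions, not part of the book's example; only the loop structure (`calc_grad` / `opt.step()` / `opt.zero_grad()`) mirrors the signature above:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

torch.manual_seed(42)
# Synthetic linearly separable data: label is 1 when the features sum > 0
x = torch.randn(200, 2)
y = (x[:, :1] + x[:, 1:] > 0).float()
dl = DataLoader(list(zip(x, y)), batch_size=32, shuffle=True)

model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_func = nn.BCEWithLogitsLoss()

def calc_grad(xb, yb, model):
    loss = loss_func(model(xb), yb)
    loss.backward()

def train_epoch(model):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()

def train_model(model, epochs):
    for _ in range(epochs):
        train_epoch(model)

train_model(model, epochs=10)
with torch.no_grad():
    acc = ((model(x) > 0).float() == y).float().mean().item()
print(f"accuracy: {acc:.2f}")
```

Swapping `nn.Linear(2, 1)` for any other `nn.Module` leaves the loop unchanged, which is the point of the abstraction.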
Import
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module or callable | Yes | The neural network to train |
| dl | DataLoader | Yes | Training DataLoader providing (xb, yb) batches |
| valid_dl | DataLoader | Yes | Validation DataLoader for metric computation |
| loss_func | callable | Yes | Loss function taking (predictions, targets) and returning a scalar loss |
| lr | float | Yes | Learning rate for parameter updates |
| n_epochs | int | Yes | Number of full passes through the training data |
Outputs
| Name | Type | Description |
|---|---|---|
| trained model | nn.Module | Model with optimized parameters after training |
| epoch metrics | list[float] | Validation accuracy (or other metric) printed after each epoch |
Usage Examples
Basic Usage: Fully Manual Loop
```python
import torch
from torch.utils.data import DataLoader

# Assume train_x, train_y, valid_x, valid_y are prepared tensors
dset = list(zip(train_x, train_y))
dl = DataLoader(dset, batch_size=256, shuffle=True)
valid_dset = list(zip(valid_x, valid_y))
valid_dl = DataLoader(valid_dset, batch_size=256)

# Define model and loss
def linear1(xb): return xb @ weights + bias

def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds > 0.5) == (yb == 1)
    return correct.float().mean()

# Initialize parameters
weights = torch.randn(28*28, 1).requires_grad_()
bias = torch.randn(1).requires_grad_()
lr = 1.0

# Training loop
for epoch in range(20):
    for xb, yb in dl:
        preds = linear1(xb)
        loss = mnist_loss(preds, yb)
        loss.backward()
        weights.data -= weights.grad * lr
        bias.data -= bias.grad * lr
        weights.grad.zero_()
        bias.grad.zero_()
    # Validation
    accs = [batch_accuracy(linear1(xb), yb) for xb, yb in valid_dl]
    val_acc = round(torch.stack(accs).mean().item(), 4)
    print(f"Epoch {epoch}: {val_acc}")
```
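The loss and metric above are easy to sanity-check on hand-picked values. This quick verification (not from the book; the test inputs are illustrative) confirms that `mnist_loss` rewards confident correct predictions and that `batch_accuracy` counts sigmoid outputs on the right side of 0.5:

```python
import torch

def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds > 0.5) == (yb == 1)
    return correct.float().mean()

# A large positive logit with target 1 gives near-zero loss;
# the same logit with target 0 gives a loss near one.
good = mnist_loss(torch.tensor([10.0]), torch.tensor([1.0]))
bad = mnist_loss(torch.tensor([10.0]), torch.tensor([0.0]))
assert good < 0.01 and bad > 0.99

# Logits [2, -2, 2] vs targets [1, 0, 0]: the first two predictions
# land on the correct side of 0.5, the third does not -> accuracy 2/3.
acc = batch_accuracy(torch.tensor([2.0, -2.0, 2.0]),
                     torch.tensor([1.0, 0.0, 0.0]))
assert abs(acc.item() - 2/3) < 1e-6
```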
Using BasicOptim
```python
linear_model = nn.Linear(28*28, 1)
opt = BasicOptim(linear_model.parameters(), lr=1.0)

def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

for epoch in range(20):
    for xb, yb in dl:
        calc_grad(xb, yb, linear_model)
        opt.step()
        opt.zero_grad()
    print(validate_epoch(linear_model), end=' ')
```
Upgrading to a Neural Network
```python
# Replace the linear model with a small neural network
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)

# The same loop works with any model
opt = BasicOptim(simple_net.parameters(), lr=0.1)
train_model(simple_net, epochs=40)
```
Using fastai SGD (Drop-in Replacement)
```python
from fastai.optimizer import SGD

linear_model = nn.Linear(28*28, 1)
opt = SGD(linear_model.parameters(), lr=1.0)
train_model(linear_model, 20)
# Output: 0.4932 0.8618 0.8203 0.9102 ...
```
Related Pages
Implements Principle
Requires Environment