# Implementation: Fastai Fastbook SGD Manual
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Optimization |
| Last Updated | 2026-02-09 17:00 GMT |
## Overview
Concrete pattern for implementing stochastic gradient descent (SGD) from scratch using PyTorch tensor operations, `requires_grad_()`, and `backward()`.
## Description
This implementation walks through the complete SGD loop as built in fastbook Chapter 4. It starts with a synthetic quadratic-fitting example to illustrate the seven-step process (initialize, predict, compute loss, compute gradients, step the parameters, zero the gradients, repeat), then applies the same loop to MNIST digit classification. Every step is implemented manually: random initialization, prediction, loss computation, gradient calculation via `backward()`, parameter update, and gradient zeroing.
## Usage
Use this pattern when learning SGD from scratch, or when you need to understand every detail of the optimization loop before delegating to a library optimizer such as `torch.optim.SGD`.
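For contrast, here is what the same update-and-zero cycle looks like once delegated to `torch.optim.SGD`. This is a minimal sketch on illustrative toy data (the `x`, `y`, and `w` names are not from the book); `opt.step()` and `opt.zero_grad()` replace the manual `p.data -= p.grad * lr` and `p.grad.zero_()` lines:

```python
import torch

# Toy data: targets follow y = 3x exactly (illustrative, not from the book)
x = torch.arange(10).float().unsqueeze(1)
y = 3 * x

w = torch.randn(1, 1, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.01)

for _ in range(200):
    loss = ((x @ w - y) ** 2).mean()  # MSE loss
    loss.backward()                   # compute gradients
    opt.step()                        # replaces manual p.data -= p.grad * lr
    opt.zero_grad()                   # replaces manual p.grad.zero_()
```

After training, `w` converges to the true slope of 3.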
## Code Reference
### Source Location
- Repository: fastbook
- File: `04_mnist_basics.ipynb` (Chapter 4), "SGD from scratch" section
### Signature
The key interfaces the user implements:

```python
# 1. Initialize parameters with requires_grad
def init_params(size, std=1.0):
    return (torch.randn(size) * std).requires_grad_()

# 2. Define a model (e.g., linear)
def linear1(xb):
    return xb @ weights + bias

# 3. Define a loss function
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

# 4. Compute gradients
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

# 5. Update parameters (one training epoch)
def train_epoch(model, lr, params):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad * lr
            p.grad.zero_()

# 6. Validate
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)
```
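`validate_epoch` calls a `batch_accuracy` helper that this page does not define. A sketch consistent with the chapter's version (raw model outputs are passed through a sigmoid and thresholded at 0.5):

```python
import torch

def batch_accuracy(xb, yb):
    # Sigmoid maps raw predictions into (0, 1); threshold at 0.5
    preds = xb.sigmoid()
    correct = (preds > 0.5) == yb
    return correct.float().mean()

# Example: raw scores 2.0 and -1.0 -> predicted classes 1 and 0
acc = batch_accuracy(torch.tensor([[2.0], [-1.0]]), torch.tensor([[1], [0]]))
# acc is 1.0 here: both predictions match their targets
```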
### Import

```python
import torch
from torch.utils.data import DataLoader
```
## I/O Contract
### Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| training data | Tensor pairs (x, y) | Yes | Feature tensors and target tensors, loaded via DataLoader |
| initial params | list[Tensor] with requires_grad=True | Yes | Randomly initialized weight and bias tensors |
| lr | float | Yes | Learning rate (e.g., 1.0 for the MNIST linear model in the book) |
| batch_size | int | Yes | Number of samples per mini-batch (e.g., 256) |
### Outputs
| Name | Type | Description |
|---|---|---|
| optimized params | list[Tensor] | Weight and bias tensors whose values minimize the loss after training |
| validation accuracy | float | Fraction of correctly classified samples in the validation set per epoch |
## Usage Examples
### Basic Usage: Quadratic Curve Fitting

```python
import torch

# Generate synthetic data: a noisy quadratic speed curve
time = torch.arange(0, 20).float()
speed = torch.randn(20) * 3 + 0.75 * (time - 9.5)**2 + 1

# Define quadratic model
def f(t, params):
    a, b, c = params
    return a * (t**2) + (b * t) + c

# Define loss
def mse(preds, targets):
    return ((preds - targets)**2).mean()

# Step 1: initialize
params = torch.randn(3).requires_grad_()
lr = 1e-5

# Steps 2-7: training loop
for i in range(2000):
    preds = f(time, params)           # Step 2: predict
    loss = mse(preds, speed)          # Step 3: loss
    loss.backward()                   # Step 4: gradients
    params.data -= lr * params.grad   # Step 5: update
    params.grad.zero_()               # Step 6: zero grads
    # Step 7: repeat

print(f"Fitted params: a={params[0]:.4f}, b={params[1]:.4f}, c={params[2]:.4f}")
```
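Step 6 is easy to overlook but essential: `backward()` accumulates into `.grad` rather than overwriting it, so skipping `zero_()` compounds gradients across iterations and corrupts the update. A minimal demonstration (the tensor `p` here is illustrative):

```python
import torch

p = torch.tensor([1.0], requires_grad=True)

(p * 2).sum().backward()
g1 = p.grad.item()        # gradient of 2*p w.r.t. p is 2.0

(p * 2).sum().backward()  # backward() ADDS to the existing .grad
g2 = p.grad.item()        # accumulated: 2.0 + 2.0 = 4.0

p.grad.zero_()            # step 6: reset before the next iteration
```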
### Full MNIST SGD Loop

```python
from fastai.vision.all import *

# Prepare data
path = untar_data(URLs.MNIST_SAMPLE)
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()
three_tensors = torch.stack([tensor(Image.open(o)) for o in threes]).float() / 255
seven_tensors = torch.stack([tensor(Image.open(o)) for o in sevens]).float() / 255
train_x = torch.cat([three_tensors, seven_tensors]).view(-1, 28*28)
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
dset = list(zip(train_x, train_y))
dl = DataLoader(dset, batch_size=256, shuffle=True)

# Initialize parameters (init_params, mnist_loss, linear1, and validate_epoch
# are defined in the Signature section; validate_epoch also requires a
# valid_dl built the same way from the 'valid' folder)
weights = init_params((28*28, 1))
bias = init_params(1)
lr = 1.0

# Training loop
for epoch in range(20):
    for xb, yb in dl:
        preds = xb @ weights + bias        # predict
        loss = mnist_loss(preds, yb)       # loss
        loss.backward()                    # gradients
        weights.data -= weights.grad * lr  # update weights
        bias.data -= bias.grad * lr        # update bias
        weights.grad.zero_()               # zero grads
        bias.grad.zero_()
    print(f"Epoch {epoch}: accuracy = {validate_epoch(linear1)}")
```
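The chapter's next move is to fold the update and zeroing steps into an optimizer object, which is the shape `torch.optim.SGD` exposes. A sketch along the lines of the book's `BasicOptim` refactor (the usage lines with `w` are an illustrative addition):

```python
import torch

class BasicOptim:
    """Minimal optimizer wrapping the manual update and zeroing steps."""
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr

    def step(self):
        # Same as the manual loop: p.data -= p.grad * lr for each param
        for p in self.params:
            p.data -= p.grad.data * self.lr

    def zero_grad(self):
        # Drop accumulated gradients; backward() will recreate them
        for p in self.params:
            p.grad = None

# Usage (illustrative): one update on a single scalar parameter
w = torch.tensor([2.0], requires_grad=True)
opt = BasicOptim([w], lr=0.1)
(w * 3).sum().backward()   # dL/dw = 3
opt.step()                 # w becomes 2.0 - 0.1 * 3 = 1.7
opt.zero_grad()
```

With this in place, the inner training loop reduces to predict, loss, `backward()`, `opt.step()`, `opt.zero_grad()`.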