
Implementation:Fastai Fastbook SGD Manual

From Leeroopedia


Knowledge Sources
Domains Deep Learning, Optimization
Last Updated 2026-02-09 17:00 GMT

Overview

Concrete pattern for implementing Stochastic Gradient Descent from scratch using PyTorch tensor operations, requires_grad_, and backward().

Description

This implementation walks through the complete SGD loop as built in fastbook Chapter 4. It starts with a synthetic quadratic-fitting example to illustrate the seven-step process, then applies it to MNIST digit classification. The user implements every step manually: random initialization, prediction, loss computation, gradient calculation via backward(), parameter update, gradient zeroing, and repetition until training stops.

Usage

Use this pattern when learning SGD from scratch or when you need to understand every detail of the optimization loop before delegating to a library optimizer like torch.optim.SGD.
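For contrast, the library version collapses the manual update and zeroing steps into opt.step() and opt.zero_grad(). A minimal sketch on hypothetical synthetic linear-regression data (the data and shapes here are illustrative assumptions, not from the book):

```python
import torch

# Hypothetical synthetic data: y = 3x + 2 plus noise
torch.manual_seed(0)
x = torch.randn(100, 1)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)

params = torch.randn(2, requires_grad=True)
opt = torch.optim.SGD([params], lr=0.1)

for _ in range(200):
    preds = x * params[0] + params[1]   # predict
    loss = ((preds - y) ** 2).mean()    # loss
    loss.backward()                     # gradients
    opt.step()                          # update (replaces p.data -= p.grad * lr)
    opt.zero_grad()                     # zero grads (replaces p.grad.zero_())
```

opt.step() applies exactly the p.data -= p.grad * lr update that the manual pattern spells out.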

Code Reference

Source Location

  • Repository: fastbook
  • File: 04_mnist_basics.ipynb (Chapter 4), "SGD from scratch" section

Signature

The key interfaces the user implements:

# 1. Initialize parameters with requires_grad
def init_params(size, std=1.0):
    return (torch.randn(size) * std).requires_grad_()

# 2. Define a model (e.g., linear)
def linear1(xb):
    return xb @ weights + bias

# 3. Define a loss function
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

# 4. Compute gradients
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

# 5. Update parameters (one training epoch)
def train_epoch(model, lr, params):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad * lr
            p.grad.zero_()

# 6. Validate
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)
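validate_epoch relies on a batch_accuracy helper defined earlier in the same chapter; for completeness, its fastbook definition (with torch imported here to keep the snippet self-contained):

```python
import torch

# 6a. Accuracy for one batch: threshold sigmoid outputs at 0.5
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds > 0.5) == yb
    return correct.float().mean()
```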

Import

import torch
from torch.utils.data import DataLoader

I/O Contract

Inputs

| Name          | Type                                  | Required | Description                                                  |
|---------------|---------------------------------------|----------|--------------------------------------------------------------|
| training data | Tensor pairs (x, y)                   | Yes      | Feature tensors and target tensors, loaded via DataLoader    |
| initial params| list[Tensor] with requires_grad=True  | Yes      | Randomly initialized weight and bias tensors                 |
| lr            | float                                 | Yes      | Learning rate (e.g., 1.0 for the MNIST linear model in the book) |
| batch_size    | int                                   | Yes      | Number of samples per mini-batch (e.g., 256)                 |

Outputs

| Name                | Type         | Description                                                              |
|---------------------|--------------|--------------------------------------------------------------------------|
| optimized params    | list[Tensor] | Weight and bias tensors whose values minimize the loss after training    |
| validation accuracy | float        | Fraction of correctly classified samples in the validation set per epoch |
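The "initial params" input can be sanity-checked directly: init_params from the Signature section yields leaf tensors that autograd will track.

```python
import torch

def init_params(size, std=1.0):
    # Random init scaled by std, with gradient tracking enabled in place
    return (torch.randn(size) * std).requires_grad_()

weights = init_params((28*28, 1))
bias = init_params(1)
assert weights.requires_grad and bias.requires_grad
```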

Usage Examples

Basic Usage: Quadratic Curve Fitting

import torch
import matplotlib.pyplot as plt

# Generate synthetic data
time = torch.arange(0, 20).float()
speed = torch.randn(20) * 3 + 0.75 * (time - 9.5)**2 + 1

# Define quadratic model
def f(t, params):
    a, b, c = params
    return a * (t**2) + (b * t) + c

# Define loss
def mse(preds, targets):
    return ((preds - targets)**2).mean()

# Step 1: Initialize
params = torch.randn(3).requires_grad_()
lr = 1e-5

# Steps 2-7: Training loop
for i in range(2000):
    preds = f(time, params)             # Step 2: predict
    loss = mse(preds, speed)            # Step 3: loss
    loss.backward()                     # Step 4: gradients
    params.data -= lr * params.grad     # Step 5: update
    params.grad.zero_()                 # Step 6: zero grads
    # Step 7: repeat

print(f"Fitted params: a={params[0]:.4f}, b={params[1]:.4f}, c={params[2]:.4f}")
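Step 6 deserves emphasis: backward() accumulates into .grad rather than overwriting it, so skipping grad.zero_() silently adds stale gradients into every later update. A minimal illustration (not from the book):

```python
import torch

p = torch.tensor([1.0], requires_grad=True)

(p * 2).sum().backward()
print(p.grad)          # tensor([2.])

# A second backward() without zeroing accumulates: 2 + 2 = 4
(p * 2).sum().backward()
print(p.grad)          # tensor([4.])

# Zeroing restores the correct per-step gradient
p.grad.zero_()
(p * 2).sum().backward()
print(p.grad)          # tensor([2.])
```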

Full MNIST SGD Loop

from fastai.vision.all import *

# Prepare data
path = untar_data(URLs.MNIST_SAMPLE)
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()

three_tensors = torch.stack([tensor(Image.open(o)) for o in threes]).float() / 255
seven_tensors = torch.stack([tensor(Image.open(o)) for o in sevens]).float() / 255

train_x = torch.cat([three_tensors, seven_tensors]).view(-1, 28*28)
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)

dset = list(zip(train_x, train_y))
dl = DataLoader(dset, batch_size=256, shuffle=True)

# Initialize parameters
weights = init_params((28*28, 1))
bias = init_params(1)
lr = 1.0

# Training loop
for epoch in range(20):
    for xb, yb in dl:
        preds = xb @ weights + bias         # predict
        loss = mnist_loss(preds, yb)         # loss
        loss.backward()                      # gradients
        weights.data -= weights.grad * lr    # update weights
        bias.data -= bias.grad * lr          # update bias
        weights.grad.zero_()                 # zero grads
        bias.grad.zero_()
    # validate_epoch (Signature section) expects valid_dl and batch_accuracy
    # to exist; build valid_dl from path/'valid' the same way dl was built
    print(f"Epoch {epoch}: accuracy = {validate_epoch(linear1)}")
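The validation path can also be exercised standalone on synthetic data; in this sketch the shapes, the stand-in labels, and the explicit valid_dl argument are assumptions for illustration, not from the book (fastbook keeps valid_dl as a global):

```python
import torch
from torch.utils.data import DataLoader

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    return ((preds > 0.5) == yb).float().mean()

def validate_epoch(model, valid_dl):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)

# Synthetic stand-in for the MNIST validation set (28*28 features)
torch.manual_seed(0)
valid_x = torch.randn(512, 28 * 28)
valid_y = (valid_x[:, 0] > 0).float().unsqueeze(1)
valid_dl = DataLoader(list(zip(valid_x, valid_y)), batch_size=256)

# A hand-built "model" that reads only feature 0, matching the labels above
weights = torch.zeros(28 * 28, 1)
weights[0, 0] = 10.0
acc = validate_epoch(lambda xb: xb @ weights, valid_dl)
print(acc)
```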

Related Pages

Implements Principle

Requires Environment
