
Implementation:Huggingface Diffusers LoRA Training Loop

From Leeroopedia
Knowledge Sources
Domains Diffusion_Models, Training_Loops, Loss_Functions
Last Updated 2026-02-13 21:00 GMT

Overview

Concrete tool for executing the forward pass and loss computation of diffusion model LoRA training, as implemented in the Diffusers text-to-image training example.

Description

The LoRA training loop iterates over epochs and batches, performing the diffusion training forward pass within an accelerator.accumulate() context for gradient accumulation. Each step encodes images to latents via the frozen VAE, samples random noise and timesteps, creates noisy latents via the noise scheduler, computes text embeddings with the frozen text encoder, runs the UNet forward pass to predict noise (or velocity), and computes the MSE loss.

The loss supports two prediction types: epsilon (noise prediction, the default for most SD models) and v_prediction (velocity prediction, used by the SD 2.x 768-resolution checkpoints). When snr_gamma is set, Min-SNR loss weighting is applied using the compute_snr utility to reweight the loss per timestep.
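
The two parameterizations differ only in what the MSE compares against. As a minimal scalar sketch (the helper names are illustrative, not from the script), the velocity target follows the per-element formula behind DDPMScheduler.get_velocity, where alpha_bar_t is the cumulative noise-schedule product at timestep t:

```python
import math

def epsilon_target(eps: float) -> float:
    # epsilon prediction: the target is simply the sampled noise
    return eps

def velocity_target(x0: float, eps: float, alpha_bar_t: float) -> float:
    # v-prediction target: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0
    # (the per-element math computed by DDPMScheduler.get_velocity)
    return math.sqrt(alpha_bar_t) * eps - math.sqrt(1.0 - alpha_bar_t) * x0

# At alpha_bar_t = 1 (no noise added yet) velocity equals the noise itself;
# at alpha_bar_t = 0 (pure noise) it equals the negated clean sample.
print(velocity_target(x0=0.5, eps=1.0, alpha_bar_t=1.0))  # → 1.0
print(velocity_target(x0=1.0, eps=0.0, alpha_bar_t=0.0))  # → -1.0
```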

After loss computation, backpropagation is performed via accelerator.backward(), gradients are clipped to max_grad_norm when gradients are synchronized, and the optimizer and scheduler are stepped. Training metrics are gathered across processes and logged.
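
Because accelerator.accumulate() defers the optimizer step until the configured number of micro-batches has been processed, the batch size the optimizer effectively sees is larger than the per-device micro-batch. A small illustrative helper (not part of the script) makes the arithmetic explicit:

```python
def effective_batch_size(micro_batch: int, accum_steps: int, num_processes: int) -> int:
    # accelerator.accumulate() only syncs gradients every `accum_steps`
    # micro-batches, so each optimizer step averages gradients over
    # micro_batch * accum_steps samples per process, across all processes.
    return micro_batch * accum_steps * num_processes

print(effective_batch_size(micro_batch=4, accum_steps=8, num_processes=2))  # → 64
```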

Usage

Use this training loop when:

  • Fine-tuning Stable Diffusion with LoRA
  • You need support for both epsilon and v-prediction models
  • You want optional Min-SNR loss weighting for training stability
  • Training with gradient accumulation across multiple micro-batches

Code Reference

Source Location

  • Repository: diffusers
  • File: examples/text_to_image/train_text_to_image_lora.py
  • Lines: 860-934

Signature

for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # Encode images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise and timesteps
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()

            # Forward diffusion
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Text conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

            # Noise (or velocity) prediction
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

            # Target depends on the scheduler's prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)

            # Loss computation
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Backpropagation
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

Import

import torch
import torch.nn.functional as F

I/O Contract

Inputs

Name Type Required Description
batch["pixel_values"] torch.Tensor Yes Batch of preprocessed images, shape [B, 3, H, W], range [-1, 1].
batch["input_ids"] torch.Tensor Yes Batch of tokenized captions, shape [B, max_length], int64.
prediction_type str No Prediction parameterization: "epsilon" or "v_prediction". Read from the scheduler config.
noise_offset float No Magnitude of per-channel noise offset. Default: 0 (disabled). Typical value: 0.1.
snr_gamma float No Min-SNR gamma parameter. Default: None (disabled). Typical value: 5.0.
max_grad_norm float No Maximum gradient norm for clipping. Default: 1.0.
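
The noise_offset input modifies the noise-sampling step: a per-channel constant is added to the base Gaussian noise, broadcast over the spatial dimensions. A hedged sketch of that pattern, wrapped in an illustrative helper (the function name is not from the script):

```python
import torch

def sample_offset_noise(latents: torch.Tensor, noise_offset: float = 0.1) -> torch.Tensor:
    # Base Gaussian noise, same shape as the latents
    noise = torch.randn_like(latents)
    if noise_offset:
        # Per-channel offset broadcast over H and W; encourages the model
        # to learn overall brightness/color shifts
        noise += noise_offset * torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
        )
    return noise

noise = sample_offset_noise(torch.zeros(2, 4, 8, 8), noise_offset=0.1)
print(noise.shape)  # → torch.Size([2, 4, 8, 8])
```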

Outputs

Name Type Description
loss torch.Tensor Scalar MSE loss value (optionally with Min-SNR weighting).
train_loss float Accumulated loss averaged over gradient accumulation steps, gathered across all processes.
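
The train_loss bookkeeping divides each micro-batch loss by the number of accumulation steps and resets at every sync point. A pure-Python sketch of that accumulation pattern (an illustrative simplification that omits the cross-process gather):

```python
def accumulate_train_loss(micro_losses, accum_steps):
    # Mirrors: train_loss += loss / gradient_accumulation_steps,
    # logged and reset whenever gradients sync (every accum_steps micro-batches)
    logged = []
    train_loss = 0.0
    for i, loss in enumerate(micro_losses, start=1):
        train_loss += loss / accum_steps
        if i % accum_steps == 0:  # optimizer step taken here
            logged.append(train_loss)
            train_loss = 0.0
    return logged

print(accumulate_train_loss([0.4, 0.6, 0.2, 0.8], accum_steps=2))  # → [0.5, 0.5]
```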

Usage Examples

Basic Usage

import torch
import torch.nn.functional as F

for epoch in range(num_epochs):
    unet.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # 1. Encode images to latent space
            latents = vae.encode(
                batch["pixel_values"].to(dtype=weight_dtype)
            ).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # 2. Sample noise
            noise = torch.randn_like(latents)

            # 3. Sample random timesteps
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (bsz,), device=latents.device,
            ).long()

            # 4. Add noise to latents (forward diffusion)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # 5. Get text embeddings
            encoder_hidden_states = text_encoder(
                batch["input_ids"], return_dict=False
            )[0]

            # 6. Predict noise
            model_pred = unet(
                noisy_latents, timesteps, encoder_hidden_states,
                return_dict=False,
            )[0]

            # 7. Compute loss (epsilon prediction)
            target = noise
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # 8. Backpropagate and optimize
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    filter(lambda p: p.requires_grad, unet.parameters()),
                    max_norm=1.0,
                )
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

Min-SNR Weighted Loss

from diffusers.training_utils import compute_snr

snr = compute_snr(noise_scheduler, timesteps)
snr_gamma = 5.0
mse_loss_weights = torch.stack(
    [snr, snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]

if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
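
As a scalar sanity check of the weighting above (an illustrative helper, not from the script): Min-SNR clamps the SNR at gamma and then normalizes by the per-timestep MSE scale, so low-noise timesteps (high SNR) are down-weighted while high-noise timesteps keep a weight near 1.

```python
def min_snr_weight(snr: float, gamma: float = 5.0, prediction_type: str = "epsilon") -> float:
    # Clamp SNR at gamma, then divide by the prediction-type-specific scale
    w = min(snr, gamma)
    if prediction_type == "epsilon":
        return w / snr
    elif prediction_type == "v_prediction":
        return w / (snr + 1.0)
    raise ValueError(f"unknown prediction_type: {prediction_type}")

print(min_snr_weight(50.0))  # → 0.1  (5 / 50: high-SNR step down-weighted)
print(min_snr_weight(2.0))   # → 1.0  (2 / 2: below gamma, full weight)
```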
