
Implementation:Huggingface Diffusers DreamBooth Training Loop

From Leeroopedia
Metadata
Last Updated 2026-02-13 00:00 GMT

Overview

This page documents the core training-loop pattern for DreamBooth LoRA fine-tuning: a dual-objective denoising loss with optional prior preservation. The pattern combines latent encoding, noise scheduling, UNet forward passes, instance/class loss decomposition via torch.chunk(), and gradient-clipped optimizer steps.

Description

The training loop iterates over epochs and batches from the DreamBooth dataloader. Each step performs:

  1. Latent encoding -- Batch pixel values are encoded through the frozen VAE to obtain latent representations, then scaled by the VAE's scaling factor.
  2. Noise injection -- Random Gaussian noise is sampled, random timesteps are drawn, and the noise scheduler adds noise to the latents at the corresponding timestep (forward diffusion).
  3. Text conditioning -- Text embeddings are obtained either from pre-computed hidden states or by running the text encoder on tokenized prompts.
  4. UNet prediction -- The noisy latents, timesteps, and text embeddings are passed through the UNet. If the model predicts variance (6 output channels), only the first 3 channels (noise prediction) are kept.
  5. Target computation -- The target is set based on prediction_type: raw noise for "epsilon", or velocity for "v_prediction".
  6. Loss computation with prior preservation -- When enabled, torch.chunk(model_pred, 2, dim=0) splits predictions into instance and class halves. Separate MSE losses are computed and combined: loss = instance_loss + prior_loss_weight * prior_loss.
  7. Gradient update -- Backpropagation, gradient clipping via accelerator.clip_grad_norm_(), optimizer step, and learning rate scheduler step.

The loop also handles checkpointing at regular intervals, managing checkpoint rotation via checkpoints_total_limit.
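
The rotation logic itself is not shown in the code excerpt below. A minimal sketch of how checkpoints_total_limit can be enforced, assuming checkpoints are saved as checkpoint-&lt;step&gt; directories as the script does (simplified stand-in, not the script's exact code):

```python
import os
import shutil

def rotate_checkpoints(output_dir: str, checkpoints_total_limit: int) -> None:
    """Delete the oldest checkpoint-<step> directories so that at most
    `checkpoints_total_limit` remain."""
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    # Sort by global step (the numeric suffix), so the oldest come first
    checkpoints = sorted(checkpoints, key=lambda d: int(d.split("-")[1]))
    if len(checkpoints) > checkpoints_total_limit:
        for stale in checkpoints[: len(checkpoints) - checkpoints_total_limit]:
            shutil.rmtree(os.path.join(output_dir, stale))
```

The real script runs this pruning just before saving a new checkpoint, so the limit counts the about-to-be-saved checkpoint as well.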

Usage

The training loop is invoked as the main training phase after model setup, LoRA injection, and dataloader creation:

for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # ... forward pass, loss computation, backward pass ...

Code Reference

Source Location

  • Repository: huggingface/diffusers
  • File: examples/dreambooth/train_dreambooth_lora.py (lines 1241--1370)

Signature

# Training loop pattern (inline in main(), not a standalone function)

for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

            # Encode to latent space
            if vae is not None:
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
            else:
                model_input = pixel_values

            # Sample noise and timesteps
            noise = torch.randn_like(model_input)
            bsz, channels, height, width = model_input.shape
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,),
                device=model_input.device
            ).long()

            # Add noise (forward diffusion)
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

            # Get text embeddings
            encoder_hidden_states = encode_prompt(
                text_encoder, batch["input_ids"], batch["attention_mask"], ...
            )

            # UNet prediction
            model_pred = unet(
                noisy_model_input, timesteps, encoder_hidden_states,
                class_labels=class_labels, return_dict=False,
            )[0]

            # Handle variance prediction models
            if model_pred.shape[1] == 6:
                model_pred, _ = torch.chunk(model_pred, 2, dim=1)

            # Compute target
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(model_input, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Prior preservation loss decomposition
            if args.with_prior_preservation:
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Backward and optimize
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
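
For the "v_prediction" branch, the velocity target is v_t = sqrt(abar_t) * noise - sqrt(1 - abar_t) * x_0, where abar_t is the cumulative product of alphas at timestep t. A self-contained sketch of that computation (a simplified stand-in for the scheduler's get_velocity() method, assuming a precomputed alphas_cumprod table rather than a real DDPMScheduler):

```python
import torch

def get_velocity(sample, noise, timesteps, alphas_cumprod):
    """Velocity target: v_t = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0.
    Simplified stand-in for DDPMScheduler.get_velocity()."""
    abar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)  # broadcast over C, H, W
    return abar.sqrt() * noise - (1 - abar).sqrt() * sample

# Sanity check: x_0 is exactly recoverable from (x_t, v_t), because
# x_t = sqrt(abar) * x_0 + sqrt(1 - abar) * eps and abar + (1 - abar) = 1.
alphas_cumprod = torch.linspace(0.999, 0.01, 1000)  # toy schedule
x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.tensor([10, 500])
abar = alphas_cumprod[t].view(-1, 1, 1, 1)
x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * noise   # forward diffusion
v = get_velocity(x0, noise, t, alphas_cumprod)
x0_rec = abar.sqrt() * x_t - (1 - abar).sqrt() * v   # invert from velocity
```

The recovery identity is why v-prediction models can reconstruct both x_0 and the noise from a single network output.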

Import

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import DDPMScheduler

I/O Contract

Inputs

Input Contract

  • batch["pixel_values"] -- Tensor [B, 3, H, W]. Preprocessed images from the DreamBooth dataloader. When prior preservation is enabled, B = 2 * train_batch_size (instance + class concatenated).
  • batch["input_ids"] -- Tensor [B, seq_len]. Tokenized prompts; instance prompts followed by class prompts when prior preservation is enabled.
  • args.with_prior_preservation -- bool. Whether to apply prior preservation loss decomposition.
  • args.prior_loss_weight -- float. Lambda weight for the prior preservation loss term.
  • args.max_grad_norm -- float. Maximum gradient norm for clipping (default 1.0).
  • noise_scheduler.config.prediction_type -- str. Either "epsilon" or "v_prediction".
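
The doubled batch in the contract above comes from the dataloader's collate step: with prior preservation enabled, instance examples and class examples are concatenated along the batch dimension. A minimal sketch with hypothetical example tensors (the real collate_fn also pads and stacks input_ids):

```python
import torch

train_batch_size = 2

# Hypothetical per-example image tensors from the dataset
instance_images = [torch.randn(3, 64, 64) for _ in range(train_batch_size)]
class_images = [torch.randn(3, 64, 64) for _ in range(train_batch_size)]

# Instance examples come first, class examples second, so the later
# torch.chunk(..., 2, dim=0) recovers the two halves in that order.
pixel_values = torch.stack(instance_images + class_images)
```

This ordering is the invariant that the loss decomposition in step 6 relies on.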

Outputs

Output Contract

  • loss -- Tensor (scalar). Combined training loss: instance_loss + lambda * prior_loss (or just instance_loss without prior preservation).
  • Side effects -- Parameter updates. LoRA adapter weights in the UNet (and optionally the text encoder) are updated via the optimizer.
  • Checkpoints -- Files on disk. Training state checkpoints saved at intervals defined by checkpointing_steps.
  • Logs -- Dict. Loss and learning rate values logged to the configured reporting backend (TensorBoard, W&B, etc.).

Usage Examples

Example 1: Loss Decomposition with Prior Preservation

This example demonstrates how torch.chunk splits the doubled batch for separate instance and prior loss computation. It is self-contained: random tensors stand in for real predictions and targets.

import torch
import torch.nn.functional as F

# Assume batch_size=4, with_prior_preservation=True:
# the batch stacks 4 instance and 4 class examples, so
# model_pred shape: [8, 4, 64, 64]  (4 instance + 4 class predictions)
# target shape:     [8, 4, 64, 64]
model_pred = torch.randn(8, 4, 64, 64)
target = torch.randn(8, 4, 64, 64)

model_pred_instance, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target_instance, target_prior = torch.chunk(target, 2, dim=0)

# model_pred_instance shape: [4, 4, 64, 64]
# model_pred_prior shape:    [4, 4, 64, 64]

instance_loss = F.mse_loss(model_pred_instance.float(), target_instance.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

loss = instance_loss + 1.0 * prior_loss  # prior_loss_weight = 1.0

Example 2: Training Without Prior Preservation

When prior preservation is disabled, no chunking is needed.

# model_pred shape: [4, 4, 64, 64] (instance predictions only)
# target shape:     [4, 4, 64, 64]

loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

Related Pages

Requires Environment
