Implementation: Hugging Face Diffusers DreamBooth Training Loop
| Knowledge Sources | |
|---|---|
| Domains | |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
The core training loop pattern for DreamBooth LoRA fine-tuning, implementing dual-objective denoising loss with optional prior preservation. This pattern combines latent encoding, noise scheduling, UNet forward passes, instance/class loss decomposition via torch.chunk(), and gradient-clipped optimizer steps.
Description
The training loop iterates over epochs and batches from the DreamBooth dataloader. Each step performs:
- Latent encoding -- Batch pixel values are encoded through the frozen VAE to obtain latent representations, then scaled by the VAE's scaling factor.
- Noise injection -- Random Gaussian noise is sampled, random timesteps are drawn, and the noise scheduler adds noise to the latents at the corresponding timestep (forward diffusion).
- Text conditioning -- Text embeddings are obtained either from pre-computed hidden states or by running the text encoder on tokenized prompts.
- UNet prediction -- The noisy latents, timesteps, and text embeddings are passed through the UNet. If the model predicts variance (6 output channels), only the first 3 channels (noise prediction) are kept.
- Target computation -- The target is set based on prediction_type: raw noise for "epsilon", or velocity for "v_prediction".
- Loss computation with prior preservation -- When enabled, torch.chunk(model_pred, 2, dim=0) splits predictions into instance and class halves. Separate MSE losses are computed and combined: loss = instance_loss + prior_loss_weight * prior_loss.
- Gradient update -- Backpropagation, gradient clipping via accelerator.clip_grad_norm_(), optimizer step, and learning rate scheduler step.
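The target computation step above can be illustrated in isolation. The following is a minimal, framework-free sketch of the two prediction targets, using scalar values and the velocity formula implemented by DDPMScheduler.get_velocity (the function name and scalar setting are illustrative, not the script's actual code):

```python
import math

def training_target(prediction_type, x0, noise, alpha_bar):
    """Return the denoising target for one (scalar) latent value.

    x0        -- clean latent value
    noise     -- the Gaussian noise that was added
    alpha_bar -- cumulative product of (1 - beta) up to the sampled timestep
    """
    if prediction_type == "epsilon":
        # The model learns to predict the raw noise
        return noise
    elif prediction_type == "v_prediction":
        # Velocity target: v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x0
        return math.sqrt(alpha_bar) * noise - math.sqrt(1 - alpha_bar) * x0
    raise ValueError(f"Unknown prediction type {prediction_type}")

# epsilon target is just the noise itself
print(training_target("epsilon", x0=0.5, noise=0.1, alpha_bar=0.9))  # 0.1
# v_prediction blends the noise and the clean signal
print(training_target("v_prediction", x0=0.5, noise=0.1, alpha_bar=0.9))
```

Note that the v-prediction target depends on the sampled timestep through alpha_bar, which is why the real loop passes timesteps into get_velocity.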
The loop also handles checkpointing at regular intervals, managing checkpoint rotation via checkpoints_total_limit.
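The rotation logic can be sketched as follows. This is a hedged approximation of what the script does around checkpoints_total_limit (the rotate_checkpoints helper and the simulated directory layout are illustrative, not the script's actual code):

```python
import os
import shutil
import tempfile

def rotate_checkpoints(output_dir, total_limit):
    """Keep at most total_limit - 1 checkpoints so the next save fits the limit.

    Checkpoints are directories named checkpoint-<step>; they are sorted by
    step number and the oldest are removed first.
    """
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    checkpoints = sorted(checkpoints, key=lambda name: int(name.split("-")[1]))
    if len(checkpoints) >= total_limit:
        num_to_remove = len(checkpoints) - total_limit + 1
        for ckpt in checkpoints[:num_to_remove]:
            shutil.rmtree(os.path.join(output_dir, ckpt))

# Simulate a run that has already saved three checkpoints with a limit of 2
out = tempfile.mkdtemp()
for step in (500, 1000, 1500):
    os.makedirs(os.path.join(out, f"checkpoint-{step}"))
rotate_checkpoints(out, total_limit=2)
print(sorted(os.listdir(out)))  # ['checkpoint-1500']
```

Removing down to total_limit - 1 (rather than total_limit) leaves room for the checkpoint that is about to be written.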
Usage
The training loop is invoked as the main training phase after model setup, LoRA injection, and dataloader creation:
```python
for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # ... forward pass, loss computation, backward pass ...
```
Code Reference
Source Location
- Repository: huggingface/diffusers
- File: examples/dreambooth/train_dreambooth_lora.py (lines 1241--1370)
Signature
```python
# Training loop pattern (inline in main(), not a standalone function)
for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

            # Encode to latent space
            if vae is not None:
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
            else:
                model_input = pixel_values

            # Sample noise and timesteps
            noise = torch.randn_like(model_input)
            bsz, channels, height, width = model_input.shape
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,),
                device=model_input.device
            ).long()

            # Add noise (forward diffusion)
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

            # Get text embeddings
            encoder_hidden_states = encode_prompt(
                text_encoder, batch["input_ids"], batch["attention_mask"], ...
            )

            # UNet prediction
            model_pred = unet(
                noisy_model_input, timesteps, encoder_hidden_states,
                class_labels=class_labels, return_dict=False,
            )[0]

            # Handle variance prediction models: keep only the noise channels
            if model_pred.shape[1] == 6:
                model_pred, _ = torch.chunk(model_pred, 2, dim=1)

            # Compute target
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(model_input, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Prior preservation: split batch into instance and class halves
            if args.with_prior_preservation:
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Backward and optimize
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
```
Import
```python
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import DDPMScheduler
```
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| batch["pixel_values"] | Tensor [B, 3, H, W] | Preprocessed images from the DreamBooth dataloader. When prior preservation is enabled, B = 2 * train_batch_size (instance + class concatenated). |
| batch["input_ids"] | Tensor [B, seq_len] | Tokenized prompts. Instance prompts followed by class prompts when prior preservation is enabled. |
| args.with_prior_preservation | bool | Whether to apply prior preservation loss decomposition. |
| args.prior_loss_weight | float | Lambda weight for the prior preservation loss term. |
| args.max_grad_norm | float | Maximum gradient norm for clipping (default 1.0). |
| noise_scheduler.config.prediction_type | str | Either "epsilon" or "v_prediction". |
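The [instance | class] batch layout described above is produced by the dataloader's collate step. A pure-Python sketch of that layout (the collate_with_prior helper and the string placeholders are illustrative; the real collate function stacks tensors):

```python
def collate_with_prior(instance_examples, class_examples):
    """Stack instance examples first, then class examples, mirroring the
    [instance | class] batch layout that torch.chunk(..., dim=0) later undoes."""
    examples = list(instance_examples) + list(class_examples)
    return {
        "pixel_values": [ex["pixel_values"] for ex in examples],
        "input_ids": [ex["input_ids"] for ex in examples],
    }

# train_batch_size = 2 instance examples plus 2 class (prior) examples
instance = [{"pixel_values": f"img_inst_{i}", "input_ids": f"ids_inst_{i}"} for i in range(2)]
prior = [{"pixel_values": f"img_class_{i}", "input_ids": f"ids_class_{i}"} for i in range(2)]
batch = collate_with_prior(instance, prior)
print(len(batch["pixel_values"]))  # 4, i.e. B = 2 * train_batch_size
print(batch["pixel_values"][:2])   # instance half comes first
```

Because the ordering is fixed (instance first, class second), chunking the prediction and target along dim=0 cleanly recovers the two halves for the separate losses.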
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor (scalar) | Combined training loss: instance_loss + lambda * prior_loss (or just instance_loss without prior preservation). |
| Side effects | Parameter updates | LoRA adapter weights in the UNet (and optionally text encoder) are updated via the optimizer. |
| Checkpoints | Files on disk | Training state checkpoints saved at intervals defined by checkpointing_steps. |
| Logs | Dict | Loss and learning rate values logged to the configured reporting backend (TensorBoard, W&B, etc.). |
Usage Examples
Example 1: Loss Decomposition with Prior Preservation
Demonstrating how torch.chunk splits the batch for separate loss computation.
```python
# Assume batch_size=4, with_prior_preservation=True
# model_pred shape: [8, 4, 64, 64] (4 instance + 4 class predictions)
# target shape: [8, 4, 64, 64]
model_pred_instance, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target_instance, target_prior = torch.chunk(target, 2, dim=0)

# model_pred_instance shape: [4, 4, 64, 64]
# model_pred_prior shape: [4, 4, 64, 64]
instance_loss = F.mse_loss(model_pred_instance.float(), target_instance.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = instance_loss + 1.0 * prior_loss  # 1.0 is the default prior_loss_weight
```
Example 2: Training Without Prior Preservation
When prior preservation is disabled, no chunking is needed.
```python
# model_pred shape: [4, 4, 64, 64] (instance predictions only)
# target shape: [4, 4, 64, 64]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
```
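Example 3: Gradient Clipping by Global Norm
The clipping step can also be illustrated in isolation. This is a sketch of global-norm clipping in the style of accelerator.clip_grad_norm_, on plain Python floats rather than parameter tensors (the clip_grad_norm helper here is illustrative, not the library's implementation):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale all gradients by a common factor so their global L2 norm
    does not exceed max_norm (no-op when already within the limit)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

grads, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)  # pre-clip norm is 5.0
print([round(g, 6) for g in grads])  # [0.6, 0.8] -- direction preserved, norm clipped to 1.0
```

Because every gradient is scaled by the same factor, clipping preserves the update direction; only its magnitude is bounded, which is why the loop applies it once per synchronized step via accelerator.sync_gradients.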