Implementation:NVIDIA NeMo Aligner Train SDXL DRaFTP
| Knowledge Sources | |
|---|---|
| Domains | Multimodal, Image Generation, Diffusion Models, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
train_sdxl_draftp.py is the entry-point training script for DRaFT+ fine-tuning of Stable Diffusion XL (SDXL) models, with FSDP support, activation checkpointing, and a custom trainer builder.
Description
This script sets up the complete DRaFT+ training pipeline for SDXL models. It is more complex than the standard SD training script because SDXL's larger architecture requires additional memory optimization strategies:
- Custom Trainer Builder -- Defines a local MegatronStableDiffusionTrainerBuilder that overrides _training_strategy to support FSDP with custom extra_fsdp_wrap_module sets covering UNetModel, TimestepEmbedSequential, Decoder, ResnetBlock, AttnBlock, MegatronCLIPRewardModel, FrozenOpenCLIPEmbedder, FrozenOpenCLIPEmbedder2, FrozenCLIPEmbedder, and ParallelLinearAdapter.
- CUDA device setup -- Explicitly sets the CUDA device based on LOCAL_RANK for multi-GPU training.
- Model instantiation -- Creates a MegatronSDXLDRaFTPModel directly and initializes PEFT adapters. The reward model is loaded and attached before distributed initialization.
- Activation checkpointing -- Optionally applies non-reentrant activation checkpointing on Decoder, UNetModel, and MegatronCLIPRewardModel modules when cfg.model.activation_checkpointing is True.
- Dataset construction -- Builds train and validation datasets from text_webdataset, extracting only the "captions" field. Dataset paths are set per device per node.
- Training execution -- Creates a SupervisedTrainer (with run_init_validation=False) and calls fit() after emptying the CUDA cache.
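
The activation checkpointing and FSDP wrapping steps above both come down to selecting submodules by class. The following is a minimal sketch of that selection logic only, assuming a tiny stand-in `Module` class in place of `torch.nn.Module`; the stand-in classes and the `select_for_checkpointing` helper are hypothetical and are not taken from the script.

```python
from typing import Iterator, List, Tuple


class Module:
    """Tiny stand-in for torch.nn.Module: holds named child modules."""

    def __init__(self, **children: "Module") -> None:
        self.children = dict(children)

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "Module"]]:
        # Depth-first walk yielding (dotted_name, module), parent before children.
        yield prefix, self
        for name, child in self.children.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


# Hypothetical stand-ins named after the classes listed in the description.
class UNetModel(Module):
    pass


class Decoder(Module):
    pass


class MegatronCLIPRewardModel(Module):
    pass


class ResnetBlock(Module):
    pass


# Classes that would be wrapped when activation checkpointing is enabled.
CHECKPOINT_TARGETS = (Decoder, UNetModel, MegatronCLIPRewardModel)


def select_for_checkpointing(root: Module, enabled: bool = True) -> List[str]:
    """Return dotted names of submodules whose class is a checkpoint target."""
    if not enabled:  # mirrors cfg.model.activation_checkpointing being False
        return []
    return [
        name
        for name, module in root.named_modules()
        if isinstance(module, CHECKPOINT_TARGETS)
    ]


# A toy module tree; the layout is illustrative, not SDXL's real structure.
model = Module(
    unet=UNetModel(block=ResnetBlock()),
    vae=Module(decoder=Decoder()),
    reward=MegatronCLIPRewardModel(),
)
selected = select_for_checkpointing(model)
```

In a real PyTorch setup the same `isinstance` test would typically be passed as a predicate to the framework's checkpoint-wrapping utility rather than used to collect names.
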
The script prints the model architecture on rank 0 and synchronizes all ranks with a distributed barrier before training begins.
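
The device setup and rank-0 printing described above depend on per-process environment variables set by the launcher. Below is a hedged sketch of that logic in plain Python. `LOCAL_RANK` comes from the description; the use of `RANK` for the global rank, the default of `0`, and both helper names are assumptions made for illustration.

```python
import os
from collections.abc import Mapping


def local_device_index(env: Mapping, device_count: int) -> int:
    """Return the GPU index this process should bind to, based on LOCAL_RANK."""
    local_rank = int(env.get("LOCAL_RANK", "0"))  # assumed default: single process
    if not 0 <= local_rank < device_count:
        raise ValueError(
            f"LOCAL_RANK={local_rank} is outside the {device_count} visible devices"
        )
    return local_rank


def is_global_rank_zero(env: Mapping) -> bool:
    """True only for the process that should print the model architecture."""
    return int(env.get("RANK", "0")) == 0  # assumption: launcher exports RANK


# Simulated environment for the fourth process on an 8-GPU node.
fake_env = {"LOCAL_RANK": "3", "RANK": "11"}
device_index = local_device_index(fake_env, device_count=8)
should_print = is_global_rank_zero(fake_env)

# In the real script the index would be handed to the CUDA runtime,
# e.g. torch.cuda.set_device(device_index), using os.environ instead of fake_env.
```
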
Usage
Run this script to perform DRaFT+ training on SDXL models. It supports FSDP for memory-efficient training of the large SDXL architecture, and it requires a text-caption webdataset and a pretrained reward model checkpoint.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/mm/stable_diffusion/train_sdxl_draftp.py
- Lines: 1-272
Signature
class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder):
    def _training_strategy(self) -> NLPDDPStrategy:
def resolve_and_create_trainer(cfg, pop_trainer_key):
@hydra_runner(config_path="conf", config_name="draftp_sdxl")
def main(cfg) -> None:
Import
from nemo_aligner.models.mm.stable_diffusion.megatron_sdxl_draftp_model import MegatronSDXLDRaFTPModel
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel, get_reward_model
from nemo_aligner.algorithms.supervised import SupervisedTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg.model | DictConfig | Yes | SDXL model configuration with sampling.base, truncation_steps, kl_coeff, peft, micro_batch_size, global_batch_size |
| cfg.model.data.webdataset.local_root_path | str | Yes | Path to the webdataset containing text captions |
| cfg.rm | DictConfig | Yes | Reward model configuration |
| cfg.trainer.draftp_sd | DictConfig | Yes | DRaFT+ trainer-specific configuration |
| cfg.model.activation_checkpointing | bool | No | Enable non-reentrant activation checkpointing on Decoder, UNetModel, and MegatronCLIPRewardModel modules (default: False) |
| cfg.model.fsdp | bool | No | Enable FSDP sharding (default: False) |
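
The input contract above can be checked before launching a run. The sketch below uses plain nested dictionaries in place of a `DictConfig`; the helper names (`lookup`, `missing_required`, `optional_flags`) are hypothetical and not part of the script, and only the keys and defaults listed in the table are used.

```python
from collections.abc import Mapping
from typing import Any, Dict, List

_MISSING = object()  # sentinel so that falsy values such as {} still count as present

REQUIRED_KEYS = (
    "model",
    "model.data.webdataset.local_root_path",
    "rm",
    "trainer.draftp_sd",
)
OPTIONAL_DEFAULTS = {
    "model.activation_checkpointing": False,
    "model.fsdp": False,
}


def lookup(cfg: Mapping, dotted_key: str, default: Any = _MISSING) -> Any:
    """Walk a nested mapping along a dotted key, returning default if absent."""
    node: Any = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, Mapping) or part not in node:
            return default
        node = node[part]
    return node


def missing_required(cfg: Mapping) -> List[str]:
    """List every required key from the table that the config lacks."""
    return [key for key in REQUIRED_KEYS if lookup(cfg, key) is _MISSING]


def optional_flags(cfg: Mapping) -> Dict[str, bool]:
    """Resolve the optional booleans, falling back to the documented defaults."""
    return {
        key: bool(lookup(cfg, key, default))
        for key, default in OPTIONAL_DEFAULTS.items()
    }


example_cfg = {
    "model": {
        "data": {"webdataset": {"local_root_path": "/path/to/webdataset"}},
        "fsdp": True,
    },
    "rm": {},
    "trainer": {"draftp_sd": {}},
}
missing = missing_required(example_cfg)
flags = optional_flags(example_cfg)
```
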
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model checkpoint | File | Saved checkpoint of the DRaFT+ fine-tuned SDXL model |
| Training logs | Logs | Loss and KL penalty metrics logged via the experiment logger |
Usage Examples
# Command-line invocation:
# python examples/mm/stable_diffusion/train_sdxl_draftp.py \
# model.data.webdataset.local_root_path=/path/to/webdataset \
# rm.model.restore_from_path=/path/to/reward_model.nemo \
# model.truncation_steps=5 \
# model.kl_coeff=0.1 \
# model.fsdp=True \
# model.activation_checkpointing=True \
# model.peft.peft_scheme=lora
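
The same invocation can be assembled programmatically, which helps when sweeping over override values. This is a minimal sketch that only builds the argument list from the overrides shown above; the `build_command` helper is hypothetical, and actually launching the process (for example through a multi-GPU launcher) is left out.

```python
import shlex
from typing import Any, Dict, List

SCRIPT = "examples/mm/stable_diffusion/train_sdxl_draftp.py"

# Overrides copied from the command-line example; paths are placeholders.
overrides: Dict[str, Any] = {
    "model.data.webdataset.local_root_path": "/path/to/webdataset",
    "rm.model.restore_from_path": "/path/to/reward_model.nemo",
    "model.truncation_steps": 5,
    "model.kl_coeff": 0.1,
    "model.fsdp": True,
    "model.activation_checkpointing": True,
    "model.peft.peft_scheme": "lora",
}


def build_command(script: str, overrides: Dict[str, Any]) -> List[str]:
    """Return an argv list with one key=value override per config entry."""
    return ["python", script] + [f"{key}={value}" for key, value in overrides.items()]


command = build_command(SCRIPT, overrides)
command_line = shlex.join(command)  # shell-quoted string, handy for logging
```

Passing `command` as a list to a process launcher avoids shell quoting issues when an override value contains spaces.
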