

Implementation:NVIDIA NeMo Aligner Train SDXL DRaFTP

From Leeroopedia


Knowledge Sources
Domains: Multimodal, Image Generation, Diffusion Models, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

train_sdxl_draftp.py is the entry-point training script for DRaFT+ fine-tuning of Stable Diffusion XL (SDXL) models, with FSDP support, activation checkpointing, and a custom trainer builder.

Description

This script sets up the complete DRaFT+ training pipeline for SDXL models. It is more involved than the standard Stable Diffusion (SD) training script because SDXL's larger architecture requires additional memory-optimization strategies:

  1. Custom Trainer Builder -- Defines a local MegatronStableDiffusionTrainerBuilder that overrides _training_strategy to support FSDP, passing a custom extra_fsdp_wrap_module set covering UNetModel, TimestepEmbedSequential, Decoder, ResnetBlock, AttnBlock, MegatronCLIPRewardModel, FrozenOpenCLIPEmbedder, FrozenOpenCLIPEmbedder2, FrozenCLIPEmbedder, and ParallelLinearAdapter.
  2. CUDA device setup -- Explicitly sets the CUDA device based on LOCAL_RANK for multi-GPU training.
  3. Model instantiation -- Creates a MegatronSDXLDRaFTPModel directly and initializes PEFT adapters. The reward model is loaded and attached before distributed initialization.
  4. Activation checkpointing -- Optionally applies non-reentrant activation checkpointing on Decoder, UNetModel, and MegatronCLIPRewardModel modules when cfg.model.activation_checkpointing is True.
  5. Dataset construction -- Builds train and validation datasets from text_webdataset, extracting only the "captions" field. The dataset path list contains one entry per device on each node.
  6. Training execution -- Creates a SupervisedTrainer (with run_init_validation=False) and calls fit() after emptying the CUDA cache.
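Steps 1 and 4 both come down to choosing submodules by class. The sketch below illustrates that selection in plain Python, assuming the class sets are applied as isinstance checks while walking the module tree (in PyTorch this is the role of an FSDP auto-wrap policy or the check_fn argument of apply_activation_checkpointing). The Module base class, the empty subclasses, and the helpers select, should_fsdp_wrap, and should_checkpoint are hypothetical stand-ins that only borrow the class names listed above; they are not the real NeMo modules.

```python
# Hypothetical stand-ins: only the class names match the modules listed above.
class Module:
    """Minimal tree node mimicking the traversal order of torch.nn.Module.modules()."""

    def __init__(self, *children):
        self.children = list(children)

    def modules(self):
        """Yield this module, then every descendant, depth first."""
        yield self
        for child in self.children:
            yield from child.modules()


class UNetModel(Module): pass
class TimestepEmbedSequential(Module): pass
class Decoder(Module): pass
class ResnetBlock(Module): pass
class AttnBlock(Module): pass
class MegatronCLIPRewardModel(Module): pass
class FrozenOpenCLIPEmbedder(Module): pass
class FrozenOpenCLIPEmbedder2(Module): pass
class FrozenCLIPEmbedder(Module): pass
class ParallelLinearAdapter(Module): pass


# Assumed FSDP wrap set (step 1): each matching submodule would become its own FSDP unit.
EXTRA_FSDP_WRAP_MODULES = (
    UNetModel, TimestepEmbedSequential, Decoder, ResnetBlock, AttnBlock,
    MegatronCLIPRewardModel, FrozenOpenCLIPEmbedder, FrozenOpenCLIPEmbedder2,
    FrozenCLIPEmbedder, ParallelLinearAdapter,
)

# Module types that receive activation checkpointing (step 4).
CHECKPOINT_MODULES = (Decoder, UNetModel, MegatronCLIPRewardModel)


def should_fsdp_wrap(module):
    return isinstance(module, EXTRA_FSDP_WRAP_MODULES)


def should_checkpoint(module):
    return isinstance(module, CHECKPOINT_MODULES)


def select(root, predicate):
    """Return the class names of every submodule (depth first) for which predicate is True."""
    return [type(m).__name__ for m in root.modules() if predicate(m)]
```

For a toy tree such as Module(UNetModel(TimestepEmbedSequential(ResnetBlock(), AttnBlock())), Decoder(ResnetBlock()), MegatronCLIPRewardModel()), should_checkpoint selects only the three top-level components, while should_fsdp_wrap also picks up the nested blocks, which is why the FSDP set lists much finer-grained classes.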

The script prints the model architecture on rank 0 and synchronizes all ranks with a distributed barrier before training begins.
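A distributed barrier guarantees that no rank starts fit() until every rank has finished setup. The sketch below imitates that pattern with threading.Barrier, using one thread per simulated rank; it illustrates the synchronization under that assumption and is not the script's actual torch.distributed code. simulate_ranks is a hypothetical helper.

```python
import threading


def simulate_ranks(world_size):
    """Run one thread per simulated rank and return the ordered event log."""
    barrier = threading.Barrier(world_size)
    lock = threading.Lock()
    events = []

    def log(msg):
        with lock:
            events.append(msg)

    def rank_main(rank):
        if rank == 0:
            # Only rank 0 prints, so the architecture appears once instead of world_size times.
            log("rank 0: print model")
        log(f"rank {rank}: setup done")
        # Stand-in for a distributed barrier such as torch.distributed.barrier():
        # blocks until every rank has arrived.
        barrier.wait()
        log(f"rank {rank}: fit")

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events
```

However the threads interleave, every "setup done" entry precedes every "fit" entry, which is the property the barrier is there to enforce.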

Usage

Run this script to perform DRaFT+ training on SDXL models. The script supports FSDP for memory-efficient training of the large SDXL architecture and requires a text-caption webdataset and a pretrained reward model checkpoint.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: examples/mm/stable_diffusion/train_sdxl_draftp.py
  • Lines: 1-272

Signature

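from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner

# Signatures only; bodies are elided.
class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder):
    def _training_strategy(self) -> NLPDDPStrategy: ...

def resolve_and_create_trainer(cfg, pop_trainer_key): ...

@hydra_runner(config_path="conf", config_name="draftp_sdxl")
def main(cfg) -> None: ...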
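The role of pop_trainer_key is not spelled out above. A plausible reading, stated here as an assumption, is that it names an algorithm-specific subsection of cfg.trainer (such as cfg.trainer.draftp_sd from the Inputs table) that must be hidden while the Lightning Trainer is built, because the Trainer would reject unknown arguments. The sketch below shows that "temporarily pop, then restore" pattern with plain dicts instead of OmegaConf; temp_pop and make_trainer are hypothetical helpers, and this resolve_and_create_trainer is an illustration, not the NeMo-Aligner implementation.

```python
from contextlib import contextmanager


@contextmanager
def temp_pop(config, key):
    """Temporarily remove `key` from `config`, restoring it on exit (even on error)."""
    sentinel = object()
    value = config.pop(key, sentinel)
    try:
        yield value
    finally:
        if value is not sentinel:
            config[key] = value


def make_trainer(**trainer_kwargs):
    """Hypothetical stand-in for a trainer constructor that rejects unknown keys."""
    allowed = {"devices", "num_nodes", "precision", "max_steps"}
    unknown = set(trainer_kwargs) - allowed
    if unknown:
        raise TypeError(f"unexpected trainer arguments: {sorted(unknown)}")
    return dict(trainer_kwargs)


def resolve_and_create_trainer(cfg, pop_trainer_key):
    """Sketch: hide the algorithm-specific section while the trainer is constructed."""
    with temp_pop(cfg["trainer"], pop_trainer_key):
        return make_trainer(**cfg["trainer"])
```

Restoring the key in a finally block means the DRaFT+ settings are still available afterwards (for example, to configure the SupervisedTrainer) even if trainer construction fails.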

Import

from nemo_aligner.models.mm.stable_diffusion.megatron_sdxl_draftp_model import MegatronSDXLDRaFTPModel
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel, get_reward_model
from nemo_aligner.algorithms.supervised import SupervisedTrainer

I/O Contract

Inputs

Name | Type | Required | Description
cfg.model | DictConfig | Yes | SDXL model configuration, including sampling.base, truncation_steps, kl_coeff, peft, micro_batch_size, and global_batch_size
cfg.model.data.webdataset.local_root_path | str | Yes | Path to the webdataset containing text captions
cfg.rm | DictConfig | Yes | Reward model configuration
cfg.trainer.draftp_sd | DictConfig | Yes | DRaFT+ trainer-specific configuration
cfg.model.activation_checkpointing | bool | No | Enable activation checkpointing on the UNet, Decoder, and reward model modules (default: False)
cfg.model.fsdp | bool | No | Enable FSDP sharding (default: False)
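The dataset-construction step expands cfg.model.data.webdataset.local_root_path into a list with one entry per device on each node. A minimal sketch, assuming plain dicts instead of OmegaConf and assuming the targets are cfg.model.data.train.dataset_path and cfg.model.data.validation.dataset_path; those key names and the helpers per_rank_dataset_paths and set_dataset_paths are illustrative assumptions.

```python
def per_rank_dataset_paths(local_root_path, devices, num_nodes):
    """Return one copy of the webdataset root for every device on every node."""
    if devices < 1 or num_nodes < 1:
        raise ValueError("devices and num_nodes must be positive")
    return [local_root_path for _ in range(devices * num_nodes)]


def set_dataset_paths(cfg):
    """Point train and validation at the same webdataset root, once per device."""
    trainer, data = cfg["trainer"], cfg["model"]["data"]
    paths = per_rank_dataset_paths(
        data["webdataset"]["local_root_path"], trainer["devices"], trainer["num_nodes"]
    )
    # Separate list objects so later edits to one split do not affect the other.
    data["train"]["dataset_path"] = list(paths)
    data["validation"]["dataset_path"] = list(paths)
    return cfg
```

With 2 devices on each of 2 nodes, both splits end up with four identical entries pointing at the same caption webdataset.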

Outputs

Name | Type | Description
Trained model checkpoint | File | Saved checkpoint of the DRaFT+ fine-tuned SDXL model
Training logs | Logs | Loss and KL-penalty metrics logged via the experiment logger

Usage Examples

# Command-line invocation:
python examples/mm/stable_diffusion/train_sdxl_draftp.py \
    model.data.webdataset.local_root_path=/path/to/webdataset \
    rm.model.restore_from_path=/path/to/reward_model.nemo \
    model.truncation_steps=5 \
    model.kl_coeff=0.1 \
    model.fsdp=True \
    model.activation_checkpointing=True \
    model.peft.peft_scheme=lora
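When launching from Python (for example, in a sweep), the same Hydra overrides can be passed as separate argv entries, which avoids shell quoting issues. build_command and OVERRIDES are hypothetical helpers, not part of NeMo-Aligner; the override keys and values are copied from the command-line invocation.

```python
SCRIPT = "examples/mm/stable_diffusion/train_sdxl_draftp.py"

# Hydra-style dotted overrides taken from the command-line invocation.
OVERRIDES = {
    "model.data.webdataset.local_root_path": "/path/to/webdataset",
    "rm.model.restore_from_path": "/path/to/reward_model.nemo",
    "model.truncation_steps": 5,
    "model.kl_coeff": 0.1,
    "model.fsdp": True,  # str(True) renders as "True", matching the CLI spelling
    "model.activation_checkpointing": True,
    "model.peft.peft_scheme": "lora",
}


def build_command(overrides, script=SCRIPT, python="python"):
    """Return an argv list: the script followed by one key=value override per entry."""
    return [python, script] + [f"{key}={value}" for key, value in overrides.items()]


cmd = build_command(OVERRIDES)
# To launch from the repository root:
#   import subprocess; subprocess.run(cmd, check=True)
```

Keeping overrides in a dict makes it easy to vary one setting, such as model.kl_coeff, across runs while leaving the rest unchanged.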
