

Implementation:NVIDIA NeMo Aligner Train SD DRaFTP

From Leeroopedia


Knowledge Sources
Domains: Multimodal, Image Generation, Diffusion Models, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

train_sd_draftp.py is the entry-point training script for DRaFT+ fine-tuning of Stable Diffusion models. It orchestrates model instantiation, reward model loading, dataset construction, and the training loop, which runs through NeMo-Aligner's SupervisedTrainer.

Description

This script sets up the complete DRaFT+ training pipeline for standard Stable Diffusion models:

  1. Configuration -- Uses Hydra with config path "conf" and config name "draftp_sd". Enables TF32 for CUDA matmul on PyTorch >= 1.12. Sets dataset paths from the webdataset configuration.
  2. Trainer creation -- Creates a trainer using MegatronStableDiffusionTrainerBuilder via resolve_and_create_trainer, which temporarily removes the "draftp_sd" key from the trainer config before construction.
  3. Model instantiation -- Creates a MegatronSDDRaFTPModel directly (not loaded from checkpoint) and optionally initializes PEFT adapters.
  4. Reward model loading -- Loads a CLIP-based reward model via get_reward_model and attaches it to the model as ptl_model.reward_model.
  5. Dataset construction -- Builds train and validation datasets from text_webdataset, extracting only the "captions" field since DRaFT+ only needs text prompts (images are generated during training).
  6. Training execution -- Creates a SupervisedTrainer with the "draftp_sd" configuration section and calls fit().
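Step 5 can be illustrated with a small stand-in. The sketch below assumes each webdataset sample arrives as a dict with the prompt stored under a "captions" key (the field name comes from the description above) and discards everything else, because DRaFT+ generates its own images during training. The helper name extract_captions and the sample layout are hypothetical; the real script builds its datasets with text_webdataset rather than this function.

```python
from typing import Any, Dict, Iterable, Iterator


def extract_captions(samples: Iterable[Dict[str, Any]]) -> Iterator[str]:
    """Yield only the text prompt from each sample (hypothetical helper).

    Any image fields are ignored: DRaFT+ generates images during training,
    so the caption is the only input it needs.
    """
    for sample in samples:
        caption = sample["captions"]
        # Undecoded webdataset fields are raw bytes; decode them to str
        if isinstance(caption, bytes):
            caption = caption.decode("utf-8")
        yield caption


samples = [
    {"__key__": "000000", "captions": b"a red bicycle leaning on a wall", "jpg": b"..."},
    {"__key__": "000001", "captions": "a bowl of ramen, studio lighting"},
]
print(list(extract_captions(samples)))
```

Streaming the captions through a generator keeps the sketch in line with how webdataset pipelines are usually consumed: samples are read lazily rather than loaded all at once.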

Usage

Run this script to perform DRaFT+ training on standard Stable Diffusion models. Requires a text-caption webdataset and a pretrained reward model checkpoint.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: examples/mm/stable_diffusion/train_sd_draftp.py
  • Lines: 1-146

Signature

def resolve_and_create_trainer(cfg, pop_trainer_key):

@hydra_runner(config_path="conf", config_name="draftp_sd")
def main(cfg) -> None:
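The temporary key removal performed by resolve_and_create_trainer can be pictured as a pop-then-restore pattern. The trainer builder sees only the generic trainer settings, and the draftp_sd section is put back afterwards so SupervisedTrainer can read it. This is a minimal sketch using plain dicts; the build_trainer callable stands in for MegatronStableDiffusionTrainerBuilder and, like popped_key and resolve_and_create_trainer_sketch, is hypothetical rather than the script's actual code.

```python
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator


@contextmanager
def popped_key(section: Dict[str, Any], key: str) -> Iterator[Any]:
    """Temporarily remove `key` from `section`, restoring it on exit."""
    value = section.pop(key)
    try:
        yield value
    finally:
        # Restore even if trainer construction raises
        section[key] = value


def resolve_and_create_trainer_sketch(
    cfg: Dict[str, Any],
    pop_trainer_key: str,
    build_trainer: Callable[[Dict[str, Any]], Any],  # hypothetical builder stand-in
) -> Any:
    # Hide the algorithm-specific section so the builder only receives
    # settings it understands, then put it back for later use.
    with popped_key(cfg["trainer"], pop_trainer_key):
        return build_trainer(cfg)


cfg = {"trainer": {"devices": 8, "precision": "bf16", "draftp_sd": {"max_steps": 100}}}
seen: Dict[str, Any] = {}


def fake_builder(c: Dict[str, Any]) -> str:
    seen.update(c["trainer"])  # record what the builder was given
    return "trainer"


trainer = resolve_and_create_trainer_sketch(cfg, "draftp_sd", fake_builder)
print(sorted(seen))  # ['devices', 'precision']
print(cfg["trainer"]["draftp_sd"])  # {'max_steps': 100}
```

Using a context manager with try/finally guarantees that the draftp_sd section is restored even when trainer construction fails.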

Import

from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model
from nemo_aligner.algorithms.supervised import SupervisedTrainer

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| cfg.model | DictConfig | Yes | Stable Diffusion model configuration, including infer, truncation_steps, kl_coeff, micro_batch_size, and global_batch_size |
| cfg.model.data.webdataset.local_root_path | str | Yes | Path to the webdataset containing text captions |
| cfg.rm | DictConfig | Yes | Reward model configuration, including the checkpoint path and trainer precision |
| cfg.trainer.draftp_sd | DictConfig | Yes | DRaFT+ trainer-specific configuration (max_steps, val_check_interval, etc.) |
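The dotted names in the table are paths into the nested Hydra config. That is why a command-line override such as model.kl_coeff=0.1 ends up in cfg.model. The toy parser below only illustrates that mapping. Hydra's real override grammar is richer (it handles typed values, lists, and the + and ~ prefixes), and parse_overrides is a hypothetical helper, not part of Hydra or NeMo.

```python
import ast
from typing import Any, Dict, List


def parse_overrides(overrides: List[str]) -> Dict[str, Any]:
    """Turn 'a.b.c=value' strings into a nested dict (toy version of Hydra overrides)."""
    cfg: Dict[str, Any] = {}
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        try:
            # Interpret numeric and other Python literals
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError, TypeError):
            # Fall back to the raw string, e.g. for filesystem paths
            value = raw_value
        node = cfg
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg


cfg = parse_overrides([
    "model.data.webdataset.local_root_path=/path/to/webdataset",
    "model.truncation_steps=5",
    "model.kl_coeff=0.1",
])
print(cfg["model"]["truncation_steps"])  # 5
print(cfg["model"]["data"]["webdataset"]["local_root_path"])  # /path/to/webdataset
```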

Outputs

| Name | Type | Description |
|---|---|---|
| Trained model checkpoint | File | Saved checkpoint of the DRaFT+ fine-tuned SD model |
| Training logs | Logs | Loss and KL penalty metrics logged via the experiment logger |
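The logged loss and KL penalty are linked through kl_coeff. The sketch below assumes the common form in which the objective maximizes reward while penalizing divergence from a frozen reference model. This page does not say which KL estimate DRaFT+ uses, so kl_terms is a placeholder input and draftp_style_loss is a hypothetical name. It shows one plausible way to combine the terms and is not MegatronSDDRaFTPModel's actual loss.

```python
from statistics import fmean
from typing import Dict, Sequence


def draftp_style_loss(
    rewards: Sequence[float], kl_terms: Sequence[float], kl_coeff: float
) -> Dict[str, float]:
    """Combine reward and KL penalty into one scalar (illustrative only)."""
    reward_mean = fmean(rewards)
    # Placeholder: per-sample divergence estimates from the reference model
    kl_penalty = fmean(kl_terms)
    # Minimizing this raises reward while keeping the model near the reference
    loss = -reward_mean + kl_coeff * kl_penalty
    return {"loss": loss, "reward": reward_mean, "kl_penalty": kl_penalty}


metrics = draftp_style_loss(rewards=[0.25, 0.75], kl_terms=[0.5, 1.5], kl_coeff=0.1)
print(metrics)
```

A larger kl_coeff keeps the fine-tuned model closer to its starting point, at the cost of slower reward improvement.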

Usage Examples

# Command-line invocation:
# python examples/mm/stable_diffusion/train_sd_draftp.py \
#     model.data.webdataset.local_root_path=/path/to/webdataset \
#     rm.model.restore_from_path=/path/to/reward_model.nemo \
#     model.truncation_steps=5 \
#     model.kl_coeff=0.1
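The example above sets model.truncation_steps=5. Assume this follows the DRaFT-K idea: the reward gradient flows back only through the last K denoising steps, while earlier sampling steps run without gradient tracking. Under that assumption, the steps would split as in the sketch below. The function split_denoising_steps and the 50-step sampler are hypothetical and chosen only for illustration.

```python
from typing import List, Tuple


def split_denoising_steps(
    num_inference_steps: int, truncation_steps: int
) -> Tuple[List[int], List[int]]:
    """Return (no-grad steps, grad steps) for a sampler counting down from T-1 to 0."""
    if not 0 < truncation_steps <= num_inference_steps:
        raise ValueError("truncation_steps must be in [1, num_inference_steps]")
    timesteps = list(range(num_inference_steps - 1, -1, -1))  # T-1, ..., 1, 0
    cutoff = num_inference_steps - truncation_steps
    # Early steps: sampled without gradients, so no activations are stored
    # Final K steps: the reward gradient backpropagates through these only
    return timesteps[:cutoff], timesteps[cutoff:]


no_grad_steps, grad_steps = split_denoising_steps(num_inference_steps=50, truncation_steps=5)
print(len(no_grad_steps), grad_steps)  # 45 [4, 3, 2, 1, 0]
```

Truncating backpropagation this way bounds memory by K steps rather than the full sampling chain.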
