Implementation:NVIDIA NeMo Aligner Train SD DRaFTP
| Knowledge Sources | |
|---|---|
| Domains | Multimodal, Image Generation, Diffusion Models, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
train_sd_draftp.py is the entry-point training script for DRaFT+ fine-tuning of Stable Diffusion models, orchestrating model instantiation, reward model loading, dataset construction, and supervised training.
Description
This script sets up the complete DRaFT+ training pipeline for standard Stable Diffusion models:
- Configuration -- Uses Hydra with config path "conf" and config name "draftp_sd". Enables TF32 for CUDA matmul on PyTorch >= 1.12. Sets dataset paths from the webdataset configuration.
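- Trainer creation -- Creates a trainer using MegatronStableDiffusionTrainerBuilder via resolve_and_create_trainer, which temporarily removes the "draftp_sd" key from cfg.trainer while the trainer is constructed (so the underlying PyTorch Lightning trainer does not receive the DRaFT+-specific section) and restores it afterwards for use by the SupervisedTrainer.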
- Model instantiation -- Creates a MegatronSDDRaFTPModel directly (not loaded from checkpoint) and optionally initializes PEFT adapters.
- Reward model loading -- Loads a CLIP-based reward model via get_reward_model and attaches it to the model as ptl_model.reward_model.
- Dataset construction -- Builds train and validation datasets from text_webdataset, extracting only the "captions" field since DRaFT+ only needs text prompts (images are generated during training).
- Training execution -- Creates a SupervisedTrainer with the "draftp_sd" configuration section and calls fit().
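The "temporarily removes" step in the trainer-creation item follows a common pattern: pop a key from the config, build the object, then restore the key even if construction fails. The sketch below illustrates that pattern with plain dictionaries under stated assumptions: `temp_pop` and `build_trainer` are hypothetical helpers, and the real script operates on an OmegaConf `DictConfig` and MegatronStableDiffusionTrainerBuilder rather than a `dict` and a stub.

```python
from contextlib import contextmanager
from typing import Any, Dict, Iterator


@contextmanager
def temp_pop(cfg: Dict[str, Any], key: str) -> Iterator[Dict[str, Any]]:
    """Hypothetical helper: hide `key` from `cfg` for the duration of the block."""
    sentinel = object()
    value = cfg.pop(key, sentinel)
    try:
        yield cfg
    finally:
        # Restore the key only if it was present to begin with.
        if value is not sentinel:
            cfg[key] = value


def build_trainer(trainer_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in for a trainer builder that rejects unknown keys."""
    allowed = {"devices", "num_nodes", "precision"}
    unknown = set(trainer_cfg) - allowed
    if unknown:
        raise TypeError(f"unexpected trainer arguments: {sorted(unknown)}")
    return dict(trainer_cfg)


def resolve_and_create_trainer(cfg: Dict[str, Any], pop_trainer_key: str) -> Dict[str, Any]:
    # The algorithm-specific section (e.g. "draftp_sd") is hidden from the builder;
    # temp_pop puts it back into cfg["trainer"] when the `with` block exits.
    with temp_pop(cfg["trainer"], pop_trainer_key):
        return build_trainer(cfg["trainer"])
```

Using a context manager with `try`/`finally` means the DRaFT+ section is restored even when trainer construction raises, so later code can still read cfg.trainer.draftp_sd.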
Usage
Run this script to perform DRaFT+ training on standard Stable Diffusion models. Requires a text-caption webdataset and a pretrained reward model checkpoint.
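Because images are generated during training, each webdataset sample only needs to yield its caption. A minimal sketch of that mapping, assuming samples arrive as dictionaries with a "captions" entry; the sample layout and the helper names `captions_only` and `prompts_from_samples` are illustrative assumptions, not the script's actual code.

```python
from typing import Any, Dict, Iterable, List


def captions_only(sample: Dict[str, Any]) -> str:
    """Keep only the text prompt from a (hypothetical) webdataset sample."""
    caption = sample["captions"]
    # Undecoded webdataset fields are raw bytes; decode so the model sees a str.
    if isinstance(caption, bytes):
        caption = caption.decode("utf-8")
    return caption


def prompts_from_samples(samples: Iterable[Dict[str, Any]]) -> List[str]:
    # Every other field (sample keys, stored images, metadata) is ignored.
    return [captions_only(s) for s in samples]
```

For example, a sample such as `{"__key__": "0001", "captions": b"a red bicycle", "jpg": b"..."}` contributes only the prompt `"a red bicycle"`.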
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/mm/stable_diffusion/train_sd_draftp.py
- Lines: 1-146
Signature
def resolve_and_create_trainer(cfg, pop_trainer_key):
@hydra_runner(config_path="conf", config_name="draftp_sd")
def main(cfg) -> None:
Import
from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model
from nemo_aligner.algorithms.supervised import SupervisedTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg.model | DictConfig | Yes | Stable Diffusion model configuration including infer, truncation_steps, kl_coeff, micro_batch_size, global_batch_size |
| cfg.model.data.webdataset.local_root_path | str | Yes | Path to the webdataset containing text captions |
| cfg.rm | DictConfig | Yes | Reward model configuration including checkpoint path and trainer precision |
| cfg.trainer.draftp_sd | DictConfig | Yes | DRaFT+ trainer-specific configuration (max_steps, val_check_interval, etc.) |
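The model.truncation_steps setting controls how much of the sampling chain gradients flow through. Assuming the DRaFT-K convention (only the last K of N denoising steps are differentiated, while earlier steps run without gradient tracking), the split can be sketched as follows. `split_denoising_steps` is a hypothetical helper, and the step counts are example values rather than defaults read from the `draftp_sd` config.

```python
from typing import List, Tuple


def split_denoising_steps(num_steps: int, truncation_steps: int) -> Tuple[List[int], List[int]]:
    """Return (no_grad_steps, grad_steps) as indices into the sampling loop.

    Hypothetical helper assuming DRaFT-K style truncation: gradients are kept
    only for the final `truncation_steps` iterations of the sampler.
    """
    if not 0 <= truncation_steps <= num_steps:
        raise ValueError("truncation_steps must be between 0 and num_steps")
    cutoff = num_steps - truncation_steps
    steps = list(range(num_steps))
    return steps[:cutoff], steps[cutoff:]
```

With 50 sampling steps and `model.truncation_steps=5`, steps 0 to 44 would run without gradients (in PyTorch, typically under `torch.no_grad()`) and only steps 45 to 49 would be backpropagated through, which keeps memory bounded regardless of the total step count.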
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model checkpoint | File | Saved checkpoint of the DRaFT+ fine-tuned SD model |
| Training logs | Logs | Loss and KL penalty metrics logged via the experiment logger |
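The logged loss combines the reward signal with a KL-style penalty weighted by model.kl_coeff. As an illustration only, assuming the usual reward-maximization form `loss = -mean(reward) + kl_coeff * kl_penalty` (the exact penalty NeMo Aligner computes may differ), the arithmetic is sketched below; `draftp_style_loss` is a hypothetical name.

```python
from statistics import fmean
from typing import Sequence


def draftp_style_loss(rewards: Sequence[float], kl_penalty: float, kl_coeff: float) -> float:
    """Hypothetical loss: maximize mean reward while penalizing drift from the base model."""
    if not rewards:
        raise ValueError("need at least one reward")
    # Negate the reward so that minimizing the loss increases the reward.
    return -fmean(rewards) + kl_coeff * kl_penalty
```

With rewards of 0.30 and 0.25, a penalty of 0.4, and `kl_coeff=0.1` (the value in the usage example), the loss is -0.275 + 0.04 = -0.235. Raising kl_coeff makes drift from the base model more expensive relative to reward gains.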
Usage Examples
# Command-line invocation:
# python examples/mm/stable_diffusion/train_sd_draftp.py \
# model.data.webdataset.local_root_path=/path/to/webdataset \
# rm.model.restore_from_path=/path/to/reward_model.nemo \
# model.truncation_steps=5 \
# model.kl_coeff=0.1
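Hydra overrides are plain `key=value` arguments, so the same launch can be scripted. A minimal sketch that builds the argument list for the command above from a dictionary; the paths are placeholders and `build_command` is a hypothetical helper. The resulting list can be passed to `subprocess.run(cmd, check=True)`.

```python
import shlex
import sys
from typing import Dict, List

SCRIPT = "examples/mm/stable_diffusion/train_sd_draftp.py"


def build_command(overrides: Dict[str, object], python: str = sys.executable) -> List[str]:
    """Turn a {dotted.key: value} mapping into Hydra command-line overrides."""
    return [python, SCRIPT] + [f"{key}={value}" for key, value in overrides.items()]


if __name__ == "__main__":
    cmd = build_command({
        "model.data.webdataset.local_root_path": "/path/to/webdataset",
        "rm.model.restore_from_path": "/path/to/reward_model.nemo",
        "model.truncation_steps": 5,
        "model.kl_coeff": 0.1,
    })
    # Print a copy-pasteable shell command instead of launching training.
    print(shlex.join(cmd))
```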