Implementation:NVIDIA NeMo Aligner Anneal SDXL
| Knowledge Sources | |
|---|---|
| Domains | Multimodal, Image Generation, Diffusion Models, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
anneal_sdxl.py is an annealed sampling script for SDXL DRaFT+ models. It generates images by interpolating, at each denoising step, between the base model and the fine-tuned model with a step-dependent weight, and it supports several weighting strategies.
Description
This script performs annealed guidance sampling from a DRaFT+ fine-tuned SDXL model. Rather than using only the fine-tuned model for generation, it interpolates between the base (original) model's score function and the fine-tuned model's score function at each denoising step. This provides a controllable tradeoff between the diversity of the base model and the reward-aligned outputs of the fine-tuned model.
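The interpolation described above can be sketched in a few lines. This is a minimal illustration, assuming the two models' per-step predictions are combined as a convex combination with a weight in [0, 1]. The helper name `blend_predictions` and the use of plain Python lists in place of tensors are assumptions made for illustration, not the script's actual code.

```python
from typing import List, Sequence


def blend_predictions(base_pred: Sequence[float],
                      tuned_pred: Sequence[float],
                      weight: float) -> List[float]:
    """Convex combination of two per-step model predictions.

    weight = 0 reproduces the base model, weight = 1 the fine-tuned model.
    Hypothetical helper: real code would operate on tensors.
    """
    if not 0.0 <= weight <= 1.0:
        raise ValueError("weight must lie in [0, 1]")
    if len(base_pred) != len(tuned_pred):
        raise ValueError("predictions must have the same length")
    return [(1.0 - weight) * b + weight * t
            for b, t in zip(base_pred, tuned_pred)]


base = [0.0, 2.0]
tuned = [1.0, 4.0]
print(blend_predictions(base, tuned, 0.0))   # base model only
print(blend_predictions(base, tuned, 0.25))  # mostly base model
print(blend_predictions(base, tuned, 1.0))   # fine-tuned model only
```

Intermediate weights trade off the two models: values near 0 keep the base model's behavior, values near 1 favor the reward-aligned behavior.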
The script supports the following weighting strategies, selected via cfg.weight_type:
- base -- Uses only the base model (weight = 0 for fine-tuned model at all steps)
- draft -- Uses only the fine-tuned model (weight = 1 at all steps)
- linear -- Linearly increases the fine-tuned model's weight from 0 to 1 over the denoising process
- power_p -- Uses a power function of the denoising progress with exponent p (e.g., power_2, power_4)
- step_f -- Uses a step function that switches from the base model to the fine-tuned model at fraction f of the denoising process (e.g., step_0.6)
The default set of weight types is ["base", "draft", "linear", "power_2", "power_4", "step_0.6"] when no override is provided.
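The weight types listed above can be sketched as a small factory that turns a weight-type string into a function of the denoising progress. This is a hedged sketch, not the script's code: the name `make_weight_fn`, the `(step, total_steps)` signature, and the exact formulas (progress = step / total_steps, power = progress ** p, step = 1 once progress reaches f) are assumptions chosen to match the descriptions.

```python
from typing import Callable

# A weight function maps (current step, total steps) to the
# fine-tuned model's weight in [0, 1]. The signature is assumed.
WeightFn = Callable[[int, int], float]


def make_weight_fn(weight_type: str) -> WeightFn:
    """Map a weight-type string to a weight function (illustrative only)."""
    if weight_type == "base":
        return lambda step, total: 0.0          # base model at every step
    if weight_type == "draft":
        return lambda step, total: 1.0          # fine-tuned model at every step
    if weight_type == "linear":
        return lambda step, total: step / total
    if weight_type.startswith("power_"):
        power = float(weight_type.split("_", 1)[1])
        return lambda step, total: (step / total) ** power
    if weight_type.startswith("step_"):
        frac = float(weight_type.split("_", 1)[1])
        return lambda step, total: 1.0 if step / total >= frac else 0.0
    raise ValueError(f"unknown weight type: {weight_type!r}")


# Print the assumed schedule at the midpoint of a 10-step run.
for name in ["base", "draft", "linear", "power_2", "power_4", "step_0.6"]:
    print(name, make_weight_fn(name)(5, 10))
```

Under these assumptions, larger exponents in power_p keep the weight close to 0 for longer, so the base model dominates more of the early denoising steps.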
For each weight type, the script:
- Iterates over the validation dataloader
- Generates random latents (with reproducible seeds per rank)
- Calls ptl_model.annealed_guidance with the chosen weighting function
- Saves the generated images as PNG files and the prompts as text files in an output directory organized by weight type
The script also supports a custom single prompt override via cfg.prompt.
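The per-weight-type loop described above can be outlined as follows. This is a simplified sketch under stated assumptions: `run_annealed_sampling` and `fake_generate` are hypothetical names, `generate` stands in for the call to ptl_model.annealed_guidance, the seed scheme (base seed plus rank, re-seeded per weight type) is an assumption, and Python's `random` module replaces tensor-based latent sampling. A single custom prompt corresponds to passing a one-element prompt list.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

Image = List[float]  # stand-in for a decoded image


def run_annealed_sampling(
    prompts: Sequence[str],
    weight_types: Sequence[str],
    generate: Callable[[str, List[float], str], Image],
    rank: int = 0,
    base_seed: int = 1234,
    latent_size: int = 4,
) -> Dict[str, List[Tuple[str, Image]]]:
    """Collect (prompt, image) pairs for every weight type."""
    results: Dict[str, List[Tuple[str, Image]]] = {}
    for weight_type in weight_types:
        # Re-seed per weight type so every strategy starts from the same
        # latents on a given rank (an assumption made for comparability).
        rng = random.Random(base_seed + rank)
        samples = []
        for prompt in prompts:
            latents = [rng.gauss(0.0, 1.0) for _ in range(latent_size)]
            samples.append((prompt, generate(prompt, latents, weight_type)))
        results[weight_type] = samples
    return results


def fake_generate(prompt: str, latents: List[float], weight_type: str) -> Image:
    # Placeholder: echoes the latents instead of running a diffusion model.
    return list(latents)


out = run_annealed_sampling(["a cat", "a dog"], ["base", "draft"], fake_generate)
print(sorted(out))                   # weight types that were sampled
print(len(out["base"]))              # one sample per prompt
```

Seeding by rank keeps runs reproducible while giving each rank different latents, so multi-GPU runs do not produce duplicate images.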
The script defines the same MegatronStableDiffusionTrainerBuilder, with FSDP support, as the training script, and it supports activation checkpointing to reduce memory use.
Usage
Run this script after DRaFT+ training to generate annealed samples across multiple weighting strategies. The outputs are useful for evaluating the alignment-diversity tradeoff and for choosing a weighting function for deployment.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/mm/stable_diffusion/anneal_sdxl.py
- Lines: 1-325
Signature
class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder):
    def _training_strategy(self) -> NLPDDPStrategy:
def resolve_and_create_trainer(cfg, pop_trainer_key):
@hydra_runner(config_path="conf", config_name="draftp_sdxl")
def main(cfg) -> None:
Import
from nemo_aligner.models.mm.stable_diffusion.megatron_sdxl_draftp_model import MegatronSDXLDRaFTPModel
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel, get_reward_model
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg.model | DictConfig | Yes | SDXL model configuration matching the trained DRaFT+ model |
| cfg.model.data.webdataset.local_root_path | str | Yes | Path to the webdataset containing text captions for generation |
| cfg.weight_type | str or list | No | Comma-separated string or list of weight types; defaults to ["base", "draft", "linear", "power_2", "power_4", "step_0.6"] |
| cfg.prompt | str | No | Optional single prompt to override the validation dataset |
| cfg.exp_manager.explicit_log_dir | str | Yes | Base output directory for generated images |
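Since cfg.weight_type may arrive as either a comma-separated string or a list, some normalization step is needed before looping over weight types. The sketch below shows one way this could look. `normalize_weight_types` is a hypothetical helper, and treating a missing value as `None` is an assumption; the script's own handling is not reproduced here.

```python
from typing import Iterable, List, Optional, Union

DEFAULT_WEIGHT_TYPES = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"]


def normalize_weight_types(
    value: Optional[Union[str, Iterable[str]]],
) -> List[str]:
    """Return a list of weight types from a string, a list, or None."""
    if value is None:
        # No override provided: fall back to the documented defaults.
        return list(DEFAULT_WEIGHT_TYPES)
    if isinstance(value, str):
        # Split "linear,power_2" into ["linear", "power_2"], dropping blanks.
        return [item.strip() for item in value.split(",") if item.strip()]
    # Any other iterable (for example a config list) is copied element-wise.
    return [str(item) for item in value]


print(normalize_weight_types("linear, power_2,step_0.5"))
print(normalize_weight_types(["base", "draft"]))
print(normalize_weight_types(None))
```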
Outputs
| Name | Type | Description |
|---|---|---|
| Generated images | PNG files | Saved to {explicit_log_dir}/annealed_outputs_sdxl_{weight_type}/img_{idx}_{rank}.png |
| Prompts | Text files | Saved to {explicit_log_dir}/annealed_outputs_sdxl_{weight_type}/prompt_{idx}_{rank}.txt |
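The output layout in the table above can be expressed as a small path helper, which is convenient when collecting results for evaluation. This is an illustrative sketch: `output_paths` is a hypothetical name, and it assumes idx and rank are formatted as plain integers without zero-padding.

```python
from pathlib import PurePosixPath
from typing import Tuple


def output_paths(explicit_log_dir: str, weight_type: str,
                 idx: int, rank: int) -> Tuple[PurePosixPath, PurePosixPath]:
    """Build the image and prompt paths for one generated sample."""
    out_dir = PurePosixPath(explicit_log_dir) / f"annealed_outputs_sdxl_{weight_type}"
    image_path = out_dir / f"img_{idx}_{rank}.png"
    prompt_path = out_dir / f"prompt_{idx}_{rank}.txt"
    return image_path, prompt_path


image_path, prompt_path = output_paths("/path/to/output", "linear", 3, 0)
print(image_path)   # /path/to/output/annealed_outputs_sdxl_linear/img_3_0.png
print(prompt_path)  # /path/to/output/annealed_outputs_sdxl_linear/prompt_3_0.txt
```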
Usage Examples
# Command-line invocation with default weight types:
# python examples/mm/stable_diffusion/anneal_sdxl.py \
# model.data.webdataset.local_root_path=/path/to/webdataset \
# exp_manager.explicit_log_dir=/path/to/output
# With custom weight types:
# python examples/mm/stable_diffusion/anneal_sdxl.py \
# model.data.webdataset.local_root_path=/path/to/webdataset \
# weight_type="linear,power_2,step_0.5"
# With a single custom prompt:
# python examples/mm/stable_diffusion/anneal_sdxl.py \
# prompt="A photo of a cat sitting on a mountain at sunset"