Implementation:NVIDIA NeMo Aligner MegatronSD DRaFTP Model
| Knowledge Sources | |
|---|---|
| Domains | Multimodal, Image Generation, Diffusion Models, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
MegatronSDDRaFTPModel implements DRaFT+, an extension of DRaFT (Direct Reward Fine-Tuning), for aligning Stable Diffusion models. It supports truncated backpropagation through the final denoising steps, KL regularization against the frozen base model, and annealed guidance at inference time.
Description
The MegatronSDDRaFTPModel class extends MegatronLatentDiffusion and implements SupervisedInterface to integrate with NeMo Aligner's training framework.
During initialization, the model:
- Creates a frozen copy of the base model (self.init_model) as a LatentDiffusion instance for KL penalty computation.
- Freezes the first_stage_model (VAE encoder/decoder) since only the U-Net diffusion model is fine-tuned.
- Configures image generation parameters: height, width, downsampling_factor, in_channels, unconditional_guidance_scale, sampler_type (DDIM), inference_steps, and eta.
The generate method implements the core DRaFT+ training forward pass:
- Initializes a DDIM sampler and encodes text prompts into conditional and unconditional embeddings.
- Iterates through the denoising timesteps. All steps before the last truncation_steps run under torch.no_grad(). During the final truncation_steps steps, gradients flow through the fine-tuned model, and the noise predictions of both the fine-tuned and the base model are collected.
- Decodes the final latent with the VAE decoder and maps the result to the [0, 255] range, clamping out-of-range values.
- Returns the decoded image, fine-tuned epsilon predictions, and base model epsilon predictions.
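The loop structure described above can be sketched as follows. This is a minimal, hypothetical NumPy sketch, not the NeMo Aligner code: autograd is omitted entirely (the comments mark where the real model enables or disables gradients), the deterministic DDIM update assumes eta = 0, and the names `sketch_generate`, `ddim_timesteps`, `guided`, `to_image_range`, `eps_finetuned`, `eps_base`, and `alphas_cumprod` are illustrative. Applying classifier-free guidance to the base model's predictions, and assuming the VAE output lies in [-1, 1], are further assumptions.

```python
import numpy as np


def ddim_timesteps(num_train_steps: int, inference_steps: int) -> list:
    """Evenly spaced timesteps, ordered from noisiest to cleanest."""
    stride = num_train_steps // inference_steps
    return list(range(0, num_train_steps, stride))[::-1]


def guided(eps_fn, x, t, scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond)."""
    e_c, e_u = eps_fn(x, t, cond=True), eps_fn(x, t, cond=False)
    return e_u + scale * (e_c - e_u)


def sketch_generate(x_T, eps_finetuned, eps_base, alphas_cumprod,
                    inference_steps=50, truncation_steps=1, guidance_scale=7.5):
    """Truncated DDIM sampling loop (eta = 0) with autograd omitted.

    Returns the final latent and the stacked predictions of both models,
    collected over the last `truncation_steps` steps only.
    """
    timesteps = ddim_timesteps(len(alphas_cumprod), inference_steps)
    if not 1 <= truncation_steps <= len(timesteps):
        raise ValueError("truncation_steps must be in [1, number of steps]")
    first_truncated = len(timesteps) - truncation_steps
    x, eps_ft, eps_init = x_T, [], []
    for i, t in enumerate(timesteps):
        eps = guided(eps_finetuned, x, t, guidance_scale)
        if i >= first_truncated:
            # Real model: gradients flow through the fine-tuned U-Net here,
            # and the frozen base model is queried for the KL term.
            eps_ft.append(eps)
            eps_init.append(guided(eps_base, x, t, guidance_scale))
        # Real model: steps before `first_truncated` run under torch.no_grad().
        a_t = alphas_cumprod[t]
        a_prev = alphas_cumprod[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        x0_pred = (x - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
        x = np.sqrt(a_prev) * x0_pred + np.sqrt(1.0 - a_prev) * eps
    return x, np.stack(eps_ft), np.stack(eps_init)


def to_image_range(decoded):
    """Map VAE output assumed to lie in [-1, 1] to [0, 255], clamping outliers."""
    return np.clip((decoded + 1.0) / 2.0, 0.0, 1.0) * 255.0
```

Keeping gradients only for the last few steps bounds activation memory to truncation_steps U-Net calls instead of the full sampling chain.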
The annealed_guidance method supports inference-time control by interpolating, at each denoising step, between the score estimates of the base model and the fine-tuned model using a configurable weighting function (weighing_fn).
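A per-step interpolation of this kind can be sketched as below. This is a hypothetical illustration, not the library code: the real weighing_fn takes four arguments (the Usage Examples section passes `lambda s1, s2, i, t: i / t`), whereas this sketch drops s1 and s2 and assumes the returned scalar w is the weight placed on the fine-tuned prediction. The names `linear_weight` and `annealed_eps` are illustrative.

```python
import numpy as np


def linear_weight(i: int, num_steps: int) -> float:
    """Hypothetical schedule: the fine-tuned model's weight grows linearly with the step index."""
    return i / num_steps


def annealed_eps(eps_base: np.ndarray, eps_finetuned: np.ndarray,
                 i: int, num_steps: int, weight_fn=linear_weight) -> np.ndarray:
    """Convex blend of base and fine-tuned noise predictions at step i (assumed semantics)."""
    w = weight_fn(i, num_steps)
    if not 0.0 <= w <= 1.0:
        raise ValueError(f"weight must lie in [0, 1], got {w}")
    return (1.0 - w) * eps_base + w * eps_finetuned
```

With a linear schedule, early (noisy) steps follow the base model and later steps increasingly follow the fine-tuned model; reversing the schedule would do the opposite.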
The log_visualization method generates comparison images from both the DRaFT+ model and the base model for logging.
Usage
Instantiate this model for DRaFT+ training on Stable Diffusion (non-XL). Assign a reward model to ptl_model.reward_model before training. Use with SupervisedTrainer.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py
- Lines: 54-478
Signature
```python
class MegatronSDDRaFTPModel(MegatronLatentDiffusion, SupervisedInterface):
    def __init__(self, cfg, trainer): ...
    def generate(self, batch, x_T): ...
    def annealed_guidance(self, batch, x_T, weighing_fn=None): ...
    def get_forward_output_and_loss_func(self, validation_step=False): ...
    def get_loss_and_metrics(self, batch, forward_only=False): ...
```
Import
```python
from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Model configuration with infer (height, width, down_factor, unconditional_guidance_scale, sampler_type, inference_steps, eta), truncation_steps, kl_coeff, micro_batch_size, global_batch_size |
| trainer | Trainer | Yes | PyTorch Lightning Trainer instance |
| batch | list[str] | Yes | List of text prompts for image generation |
| x_T | Tensor | Yes | Initial noise latents of shape [B, C, H//f, W//f], where C is in_channels and f is the downsampling factor (down_factor) |
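The expected x_T shape follows directly from the image size, downsampling factor, and latent channel count. The helper `latent_shape` below is hypothetical, and the values (512x512 images, factor 8, 4 latent channels) are illustrative settings typical of Stable Diffusion 1.x, not values read from a real NeMo Aligner config.

```python
import numpy as np


def latent_shape(batch_size: int, height: int, width: int,
                 down_factor: int, in_channels: int) -> tuple:
    """Shape [B, C, H//f, W//f] expected for the initial noise x_T."""
    if height % down_factor or width % down_factor:
        raise ValueError("height and width must be divisible by down_factor")
    return (batch_size, in_channels, height // down_factor, width // down_factor)


# Illustrative values only; one prompt per latent in the batch.
prompts = ["a red bicycle leaning on a wall", "a bowl of ramen, studio lighting"]
shape = latent_shape(batch_size=len(prompts), height=512, width=512,
                     down_factor=8, in_channels=4)
x_T = np.random.default_rng(0).standard_normal(shape).astype(np.float32)
```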
Outputs
| Name | Type | Description |
|---|---|---|
| loss_value | float | Negative mean reward plus KL penalty |
| metrics | dict | Dictionary with "loss" and "kl_penalty" values; during validation also includes "images_and_captions" |
| vae_decoder_output | Tensor | Generated images scaled to [0, 255] of shape [B, C, H, W] (from generate) |
| t_eps_draft_p | Tensor | Stacked noise predictions from fine-tuned model over truncated steps |
| t_eps_init | Tensor | Stacked noise predictions from base model over truncated steps |
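The loss in the table above combines the reward with a penalty on how far the fine-tuned noise predictions drift from the base model's. For two Gaussian denoising steps with equal variance, the KL divergence is proportional to the squared distance between their means, and under the standard epsilon parameterization the means differ by a step-dependent scalar times the difference in noise predictions. The hypothetical `draftp_loss` below assumes the penalty is kl_coeff times the mean squared difference of the stacked predictions, dropping the per-step weights; the exact reduction in the real code may differ.

```python
import numpy as np


def draftp_loss(rewards, t_eps_draft_p, t_eps_init, kl_coeff):
    """Assumed form: -mean(reward) + kl_coeff * mean squared eps difference."""
    kl_penalty = kl_coeff * float(np.mean((t_eps_draft_p - t_eps_init) ** 2))
    loss = -float(np.mean(rewards)) + kl_penalty
    return loss, {"loss": loss, "kl_penalty": kl_penalty}
```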
Usage Examples
```python
import torch

from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model
from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel

# `cfg` (Hydra/OmegaConf config) and `trainer` (PyTorch Lightning Trainer)
# are assumed to be created earlier by the training script.

# Instantiate the model on the current GPU
ptl_model = MegatronSDDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device())

# Attach the reward model before training
reward_model = get_reward_model(cfg.rm, mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size)
ptl_model.reward_model = reward_model

# Training step on a batch of text prompts
loss_value, metrics = ptl_model.get_loss_and_metrics(batch, forward_only=False)

# Annealed inference: `prompts` is a list of strings, `latents` has shape [B, C, H//f, W//f]
images = ptl_model.annealed_guidance(prompts, latents, weighing_fn=lambda s1, s2, i, t: i / t)
```
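Other schedules can be passed with the same four-argument signature. The functions below are hypothetical alternatives to the linear lambda, assuming (as that lambda suggests) that the function returns a scalar weight computed from the step index i and the total step count t; what s1 and s2 contain is not documented here, so these examples ignore them.

```python
import math


def linear_weight(s1, s2, i, t):
    """Same schedule as the lambda in the usage example."""
    return i / t


def cosine_weight(s1, s2, i, t):
    """Smooth ramp from 0 at the first step to 1 at the last."""
    return 0.5 * (1.0 - math.cos(math.pi * i / t))


def step_weight(threshold=0.5):
    """Hard switch once the fraction of completed steps reaches `threshold`."""
    def fn(s1, s2, i, t):
        return 1.0 if i / t >= threshold else 0.0
    return fn
```

Any of these could be passed as `weighing_fn=cosine_weight` or `weighing_fn=step_weight(0.7)` in the annealed_guidance call above.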