

Implementation:NVIDIA NeMo Aligner MegatronSD DRaFTP Model

From Leeroopedia


Knowledge Sources
Domains: Multimodal, Image Generation, Diffusion Models, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

MegatronSDDRaFTPModel implements the DRaFT+ (Differentiable Reward Fine-Tuning Plus) alignment method for Stable Diffusion models, supporting truncated backpropagation through denoising steps, KL regularization, and annealed sampling.

Description

The MegatronSDDRaFTPModel class extends MegatronLatentDiffusion and implements SupervisedInterface to integrate with NeMo Aligner's training framework.

During initialization, the model:

  • Creates a frozen copy of the base model (self.init_model) as a LatentDiffusion instance for KL penalty computation.
  • Freezes the first_stage_model (VAE encoder/decoder) since only the U-Net diffusion model is fine-tuned.
  • Configures image generation parameters: height, width, downsampling_factor, in_channels, unconditional_guidance_scale, sampler_type (DDIM), inference_steps, and eta.
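The generation parameters fix the shape of the initial noise x_T. The VAE downsamples both spatial dimensions by the downsampling factor f, so the U-Net works on latents of shape [B, C, H//f, W//f]. The helper below is a hypothetical illustration and is not part of the NeMo Aligner API. The values 512×512, 4 latent channels, and f = 8 are assumed typical Stable Diffusion 1.x settings rather than values taken from this page.

```python
def latent_shape(batch_size, in_channels, height, width, downsampling_factor):
    """Return the (B, C, H//f, W//f) shape of the initial noise latents x_T.

    Hypothetical helper for illustration only; it is not part of NeMo Aligner.
    """
    # The VAE encoder downsamples both spatial dimensions by the same factor f,
    # so the image size must be an exact multiple of f.
    if height % downsampling_factor or width % downsampling_factor:
        raise ValueError("height and width must be divisible by downsampling_factor")
    return (
        batch_size,
        in_channels,
        height // downsampling_factor,
        width // downsampling_factor,
    )


# Assumed Stable Diffusion 1.x-style settings: 512x512 images, 4 latent channels, f = 8.
shape = latent_shape(batch_size=2, in_channels=4, height=512, width=512, downsampling_factor=8)
print(shape)  # (2, 4, 64, 64)
```

With PyTorch available, passing this tuple to torch.randn(shape) would draw an x_T of the matching shape.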

The generate method implements the core DRaFT+ training forward pass:

  1. Initializes a DDIM sampler and encodes text prompts into conditional and unconditional embeddings.
  2. Iterates through denoising timesteps. For steps before the truncation point, runs with torch.no_grad(). For the final truncation_steps, computes gradients through the fine-tuned model while collecting noise predictions from both the fine-tuned and base models.
  3. Decodes the final latent through the VAE decoder and clips to [0, 255].
  4. Returns the decoded image, fine-tuned epsilon predictions, and base model epsilon predictions.
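The steps above can be sketched as a scalar toy loop. The sketch shows which steps are differentiated through, how classifier-free guidance combines the conditional and unconditional predictions, and where both models' predictions are collected. Every name here (guided_eps, plan_steps, toy_generate) is hypothetical. Plain floats stand in for latent tensors and simple callables stand in for the fine-tuned U-Net, the frozen base model, and one DDIM update. It is a sketch under those assumptions, not the NeMo implementation.

```python
def guided_eps(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def plan_steps(inference_steps, truncation_steps):
    """Return one flag per denoising step: True if gradients flow through that step."""
    if not 0 <= truncation_steps <= inference_steps:
        raise ValueError("truncation_steps must lie in [0, inference_steps]")
    cutoff = inference_steps - truncation_steps
    # Steps before the cutoff would run under torch.no_grad() in the real model.
    return [step >= cutoff for step in range(inference_steps)]


def toy_generate(x_T, inference_steps, truncation_steps, draft_eps_fn, init_eps_fn, update_fn):
    """Scalar stand-in for the truncated denoising loop (hypothetical, not the NeMo code).

    draft_eps_fn / init_eps_fn play the fine-tuned and frozen base models;
    update_fn plays one sampler step x_t -> x_{t-1}.
    """
    x = x_T
    eps_draft_p, eps_init = [], []
    for step, track_grad in enumerate(plan_steps(inference_steps, truncation_steps)):
        eps = draft_eps_fn(x, step)
        if track_grad:
            # Only the truncated tail collects predictions from both models.
            eps_draft_p.append(eps)
            eps_init.append(init_eps_fn(x, step))
        # The trajectory is always driven by the fine-tuned model's prediction.
        x = update_fn(x, eps, step)
    return x, eps_draft_p, eps_init


print(plan_steps(inference_steps=5, truncation_steps=2))  # [False, False, False, True, True]
print(guided_eps(eps_uncond=0.5, eps_cond=1.0, guidance_scale=3.0))  # 2.0
```

Truncating the backward pass this way keeps memory proportional to truncation_steps rather than to the full number of inference steps, because activations are stored only for the tail of the trajectory.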

The annealed_guidance method supports inference-time control by interpolating between base and fine-tuned model score functions at each step using a configurable weighing function.
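One plausible form of this interpolation is a per-step convex combination of the base and fine-tuned noise predictions. The sketch below assumes a scalar weight w_i in [0, 1] on the fine-tuned model that grows linearly over the trajectory. The helper names, the argument order, and the direction of the schedule are assumptions; the actual calling convention of weighing_fn is not specified on this page.

```python
def linear_weight(step, total_steps):
    """Hypothetical schedule: weight on the fine-tuned model rises linearly from 0 to 1."""
    if total_steps < 2:
        return 1.0
    return step / (total_steps - 1)


def anneal(eps_base, eps_draft, weight):
    """Convex combination of base and fine-tuned noise predictions (element-wise)."""
    return [(1.0 - weight) * b + weight * d for b, d in zip(eps_base, eps_draft)]


weights = [linear_weight(i, 5) for i in range(5)]
print(weights)  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(anneal([0.0, 2.0], [4.0, 2.0], 0.25))  # [1.0, 2.0]
```

Reversing the schedule (starting at 1 and decaying to 0) would instead let the fine-tuned model shape the coarse layout and the base model finish the details. Which direction works better is an empirical choice.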

The log_visualization method generates comparison images from both the DRaFT+ model and the base model for logging.

Usage

Instantiate this model for DRaFT+ training on Stable Diffusion (non-XL). Assign a reward model to ptl_model.reward_model before training. Use with SupervisedTrainer.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py
  • Lines: 54-478

Signature

class MegatronSDDRaFTPModel(MegatronLatentDiffusion, SupervisedInterface):
    def __init__(self, cfg, trainer): ...
    def generate(self, batch, x_T): ...
    def annealed_guidance(self, batch, x_T, weighing_fn=None): ...
    def get_forward_output_and_loss_func(self, validation_step=False): ...
    def get_loss_and_metrics(self, batch, forward_only=False): ...

Import

from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel

I/O Contract

Inputs

Name | Type | Required | Description
cfg | DictConfig | Yes | Model configuration with an infer section (height, width, down_factor, unconditional_guidance_scale, sampler_type, inference_steps, eta) plus truncation_steps, kl_coeff, micro_batch_size, global_batch_size
trainer | Trainer | Yes | PyTorch Lightning Trainer instance
batch | list[str] | Yes | List of text prompts for image generation
x_T | Tensor | Yes | Initial noise latents of shape [B, C, H//f, W//f]
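Before building the model, it can help to check that the configuration carries every key listed above. The sketch uses a plain dict in place of an OmegaConf DictConfig. All values in example_cfg are hypothetical, and check_draftp_cfg is an illustrative helper, not part of NeMo Aligner. The batch-size check is only a necessary condition, because Megatron additionally requires divisibility by the data-parallel size.

```python
REQUIRED_INFER_KEYS = (
    "height", "width", "down_factor", "unconditional_guidance_scale",
    "sampler_type", "inference_steps", "eta",
)
REQUIRED_TOP_KEYS = ("truncation_steps", "kl_coeff", "micro_batch_size", "global_batch_size")


def check_draftp_cfg(cfg):
    """Hypothetical pre-flight check for the keys listed in the Inputs table."""
    missing = [key for key in REQUIRED_TOP_KEYS if key not in cfg]
    infer = cfg.get("infer", {})
    missing += ["infer." + key for key in REQUIRED_INFER_KEYS if key not in infer]
    if missing:
        raise KeyError(f"missing config keys: {missing}")
    # Gradients need at least one truncated step, and no more steps than the sampler runs.
    if not 0 < cfg["truncation_steps"] <= infer["inference_steps"]:
        raise ValueError("truncation_steps must be in [1, infer.inference_steps]")
    # Necessary (not sufficient): Megatron also divides by the data-parallel size.
    if cfg["global_batch_size"] % cfg["micro_batch_size"]:
        raise ValueError("global_batch_size must be a multiple of micro_batch_size")
    return True


# Hypothetical example values, for illustration only.
example_cfg = {
    "infer": {
        "height": 512, "width": 512, "down_factor": 8,
        "unconditional_guidance_scale": 7.5, "sampler_type": "DDIM",
        "inference_steps": 25, "eta": 0.0,
    },
    "truncation_steps": 1, "kl_coeff": 0.1,
    "micro_batch_size": 2, "global_batch_size": 8,
}
print(check_draftp_cfg(example_cfg))  # True
```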

Outputs

Name | Type | Description
loss_value | float | Negative mean reward plus KL penalty
metrics | dict | "loss" and "kl_penalty" values; during validation also "images_and_captions"
vae_decoder_output | Tensor | Generated images scaled to [0, 255], shape [B, C, H, W] (from generate)
t_eps_draft_p | Tensor | Stacked noise predictions from the fine-tuned model over the truncated steps
t_eps_init | Tensor | Stacked noise predictions from the base model over the truncated steps
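The loss described above can be illustrated with plain Python lists. This sketch assumes the KL penalty is kl_coeff times the mean squared difference between the fine-tuned and base epsilon predictions. That choice is motivated by the fact that, for Gaussian denoising steps sharing a variance schedule, the KL divergence between the two models reduces to a scaled squared difference of their means, which are linear in epsilon at the same x_t. The exact reduction and weighting in the source file are not shown here, so treat draftp_loss as a hypothetical approximation.

```python
def draftp_loss(rewards, eps_draft_p, eps_init, kl_coeff):
    """Hypothetical DRaFT+ loss: -mean(reward) + kl_coeff * mean squared eps difference.

    eps_draft_p and eps_init are flattened predictions over the truncated steps.
    """
    if len(eps_draft_p) != len(eps_init) or not eps_draft_p or not rewards:
        raise ValueError("need non-empty rewards and equally sized eps lists")
    mean_reward = sum(rewards) / len(rewards)
    # Penalize drift of the fine-tuned model away from the frozen base model.
    sq_diff = [(d - b) ** 2 for d, b in zip(eps_draft_p, eps_init)]
    kl_penalty = kl_coeff * sum(sq_diff) / len(sq_diff)
    loss = -mean_reward + kl_penalty
    return loss, {"loss": loss, "kl_penalty": kl_penalty}


loss, metrics = draftp_loss([1.0, 3.0], [1.0, 2.0], [0.0, 2.0], kl_coeff=0.5)
print(loss, metrics["kl_penalty"])  # -1.75 0.25
```

Minimizing this loss raises the mean reward while kl_coeff controls how far the fine-tuned model's noise predictions may drift from the base model's.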

Usage Examples

import torch

from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model

# Assumes `cfg` (the full DictConfig) and `trainer` (a PyTorch Lightning Trainer) already exist.

# Instantiate the model on the current GPU
ptl_model = MegatronSDDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device())

# Attach the reward model used to score generated images
reward_model = get_reward_model(cfg.rm, mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size)
ptl_model.reward_model = reward_model

# Training step on `batch`, a list of text prompts
loss_value, metrics = ptl_model.get_loss_and_metrics(batch, forward_only=False)

# Annealed inference: `prompts` is a list of strings, `latents` is noise of shape [B, C, H//f, W//f]
images = ptl_model.annealed_guidance(prompts, latents, weighing_fn=lambda s1, s2, i, t: i / t)
