

Implementation:Huggingface Trl AutoModelForCausalLM From Pretrained DPO

From Leeroopedia


Knowledge Sources
Domains: NLP, RLHF
Last Updated: 2026-02-06 17:00 GMT

Overview

A concrete pattern for loading the DPO policy model from a pretrained checkpoint. The loader, AutoModelForCausalLM.from_pretrained, is provided by the Transformers library, and the TRL DPO script supplies its arguments.

Description

The DPO script loads the policy model with AutoModelForCausalLM.from_pretrained, using arguments taken from the ModelConfig dataclass. It first builds a keyword-argument dictionary that always contains the model revision, the attention implementation, and the data type; the dtype string is converted to the matching torch.dtype unless it is "auto" or None. Quantization settings are added only when requested: the get_quantization_config helper reads the 4-bit and 8-bit options from ModelConfig and returns a BitsAndBytesConfig, or None when neither option is enabled. When a quantization config is returned, the script also sets device_map to the result of get_kbit_device_map(), which places the entire model on the device assigned to the current process.
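The conditional assembly of these keyword arguments can be sketched as a pure function. This is a minimal illustration, not TRL code: build_model_kwargs is a hypothetical name, and the quantization config and device map are passed in as plain values instead of being produced by get_quantization_config and get_kbit_device_map.

```python
from typing import Any, Optional


def build_model_kwargs(
    revision: str,
    attn_implementation: Optional[str],
    dtype: Any,
    quantization_config: Optional[Any] = None,
    device_map: Optional[Any] = None,
) -> dict:
    """Hypothetical mirror of the kwargs-building step in trl/scripts/dpo.py."""
    # These three keys are always forwarded to from_pretrained.
    kwargs = dict(
        revision=revision,
        attn_implementation=attn_implementation,
        dtype=dtype,
    )
    # Quantization-related keys are added only when a quantization config exists,
    # so an unquantized load never receives them at all.
    if quantization_config is not None:
        kwargs["device_map"] = device_map
        kwargs["quantization_config"] = quantization_config
    return kwargs


plain = build_model_kwargs("main", None, "auto")
quantized = build_model_kwargs(
    "main", "sdpa", "auto",
    quantization_config={"load_in_4bit": True},  # stand-in for a BitsAndBytesConfig
    device_map={"": 0},                          # stand-in for get_kbit_device_map()
)
print(sorted(plain))      # ['attn_implementation', 'dtype', 'revision']
print(sorted(quantized))  # ['attn_implementation', 'device_map', 'dtype', 'quantization_config', 'revision']
```

The unquantized call therefore looks exactly like a plain from_pretrained call with revision, attention backend, and dtype.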

Usage

Use this pattern when:

  • Loading the initial policy model for DPO training
  • Loading a model with a specific precision (bfloat16, float16, float32, or auto)
  • Loading a quantized (4-bit or 8-bit) model for memory-efficient DPO training
  • Loading a community model whose custom code lives on the Hub (trust_remote_code=True)

Code Reference

Source Location

  • Repository: TRL
  • File: trl/scripts/dpo.py (lines 93-107)

Signature

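# Policy model loading pattern from trl/scripts/dpo.py
# (assumes torch, AutoModelForCausalLM and the TRL helpers are imported,
# and model_args is a ModelConfig instance)

# Keep "auto"/None as-is; convert other strings (e.g. "bfloat16") to torch dtypes
dtype = (
    model_args.dtype
    if model_args.dtype in ["auto", None]
    else getattr(torch, model_args.dtype)
)

# Arguments passed on every load
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)

# Add device placement and quantization only when 4-bit or 8-bit loading is enabled
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    **model_kwargs,
)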
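The dtype expression at the top of the pattern passes the values "auto" and None through unchanged and converts every other string to the torch attribute of the same name (for example "bfloat16" becomes torch.bfloat16). The sketch below mirrors that rule in plain Python; the hypothetical resolve_dtype takes the lookup namespace as a parameter so it runs without PyTorch, and fake_torch is an illustrative stand-in, not a real module.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_dtype(name: Optional[str], namespace: Any) -> Any:
    """Mirror of: name if name in ["auto", None] else getattr(torch, name)."""
    if name in ["auto", None]:
        # Passed through unchanged; from_pretrained decides what these mean.
        return name
    # Any other string must name an attribute of the namespace (torch in the real script).
    return getattr(namespace, name)


# Stand-in for the torch module (assumption: real code passes `torch` itself).
fake_torch = SimpleNamespace(
    bfloat16="torch.bfloat16", float16="torch.float16", float32="torch.float32"
)

print(resolve_dtype("auto", fake_torch))      # auto
print(resolve_dtype(None, fake_torch))        # None
print(resolve_dtype("bfloat16", fake_torch))  # torch.bfloat16
```

Because the conversion is a plain getattr, an abbreviation such as "bf16" would raise AttributeError; the full names "bfloat16", "float16", and "float32" are the values to use.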

Import

import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, get_quantization_config, get_kbit_device_map

I/O Contract

Inputs

Name | Type | Required | Description
model_name_or_path | str | Yes | Hugging Face Hub model ID or local path to the pretrained checkpoint
model_revision | str | No (default: "main") | Model revision to load: branch name, tag, or commit hash
attn_implementation | str or None | No (default: None) | Attention backend, e.g. "eager", "sdpa", or "flash_attention_2"; None lets Transformers choose its default
dtype | str or None | No | Weight data type: "auto", "bfloat16", "float16", or "float32"; converted to the matching torch.dtype unless "auto" or None
trust_remote_code | bool | No (default: False) | Whether to trust and execute custom model code from the Hub
quantization_config | BitsAndBytesConfig or None | No | Built by get_quantization_config(model_args) from the 4-bit/8-bit options; passed only when not None
device_map | dict or None | No | Set to get_kbit_device_map() only when quantization_config is not None
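The quantization_config input is not set directly; get_quantization_config derives it from ModelConfig flags. The sketch below approximates that branching as recalled from recent TRL versions (an assumption; check the installed version): 4-bit is checked before 8-bit, and the result is None when neither flag is set. It is simplified, returns a plain dict of BitsAndBytesConfig keyword arguments instead of the config object, and QuantFlags and quantization_settings are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantFlags:
    """Hypothetical subset of the ModelConfig fields that drive quantization."""
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    bnb_4bit_quant_type: str = "nf4"
    use_bnb_nested_quant: bool = False
    dtype: Optional[str] = None


def quantization_settings(flags: QuantFlags) -> Optional[dict]:
    """Approximation of get_quantization_config (assumed, simplified)."""
    if flags.load_in_4bit:
        # 4-bit is checked first, so it wins if both flags are set.
        return dict(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=flags.dtype,  # assumed: compute dtype follows model dtype
            bnb_4bit_quant_type=flags.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=flags.use_bnb_nested_quant,
        )
    if flags.load_in_8bit:
        return dict(load_in_8bit=True)
    # No quantization requested: the script then omits device_map and quantization_config.
    return None


print(quantization_settings(QuantFlags()))                   # None
print(quantization_settings(QuantFlags(load_in_8bit=True)))  # {'load_in_8bit': True}
both = quantization_settings(QuantFlags(load_in_4bit=True, load_in_8bit=True, dtype="bfloat16"))
print(both["load_in_4bit"], both["bnb_4bit_quant_type"])     # True nf4
```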

Outputs

Name | Type | Description
model | PreTrainedModel | The loaded causal language model, used as the DPO policy model

Usage Examples

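# Example 1: Basic model loading for DPO
import torch
from transformers import AutoModelForCausalLM

# The dtype keyword needs a recent Transformers release (older ones call it torch_dtype);
# flash_attention_2 requires the flash-attn package and a supported GPU
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)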

# Example 2: Loading with quantization for memory efficiency
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization with bfloat16 compute (requires the bitsandbytes package)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

# meta-llama checkpoints are gated: accept the license on the Hub and log in first
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quantization_config,
    device_map="auto",
    dtype=torch.bfloat16,
)

# Example 3: Full DPO script pattern with ModelConfig
import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, get_quantization_config, get_kbit_device_map

model_args = ModelConfig(
    model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
    attn_implementation="flash_attention_2",
)

dtype = (
    model_args.dtype
    if model_args.dtype in ["auto", None]
    else getattr(torch, model_args.dtype)
)
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)

quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    **model_kwargs,
)
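Example 2 passes device_map="auto", which lets Accelerate spread layers across the available devices, whereas the script uses get_kbit_device_map(). As recalled from recent TRL source (an assumption; verify against the installed version), that helper returns a map with the single key "" (the root module, i.e. the whole model) pointing at the current process's local index when a CUDA or XPU device is available, and None otherwise. Pinning the whole model to one device per process suits data-parallel training, where each rank holds its own copy. The hypothetical kbit_device_map below takes the availability flag and process index as arguments so the rule runs without Accelerate.

```python
from typing import Dict, Optional


def kbit_device_map(accelerator_available: bool, local_process_index: int) -> Optional[Dict[str, int]]:
    """Hypothetical mirror of the placement rule assumed for get_kbit_device_map."""
    if accelerator_available:
        # The empty-string key refers to the root module, i.e. the entire model.
        return {"": local_process_index}
    # Without an accelerator, no device map is returned.
    return None


# Two data-parallel processes on one node each load a full copy on their own device.
for rank in range(2):
    print(rank, kbit_device_map(True, rank))
# 0 {'': 0}
# 1 {'': 1}
print(kbit_device_map(False, 0))  # None
```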

Related Pages

Implements Principle
