Implementation: Hugging Face TRL AutoModelForCausalLM.from_pretrained (DPO)
| Knowledge Sources | |
|---|---|
| Domains | NLP, RLHF |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Concrete tool for loading and initializing the DPO policy model from a pretrained checkpoint, provided by the Transformers library and orchestrated by the TRL DPO script.
Description
The DPO script loads the policy model using AutoModelForCausalLM.from_pretrained with parameters derived from the ModelConfig dataclass. The loading pattern constructs a keyword arguments dictionary that includes the model revision, attention implementation, data type, and optionally quantization configuration and device map. Quantization (4-bit or 8-bit) is applied via the get_quantization_config helper, which reads settings from ModelConfig and returns a BitsAndBytesConfig when applicable. When quantization is active, get_kbit_device_map() provides the appropriate device placement strategy.
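The contract of the quantization helper can be sketched in simplified form. Note that ModelConfigSketch and the plain-dict return value below are illustrative stand-ins only: the real ModelConfig lives in trl, and the real get_quantization_config returns a transformers BitsAndBytesConfig (or None when no quantized loading is requested).

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfigSketch:
    # Illustrative stand-in; field names follow TRL's ModelConfig.
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    bnb_4bit_quant_type: str = "nf4"

def get_quantization_config_sketch(args: ModelConfigSketch) -> Optional[dict]:
    """Mirror the helper's contract: return a quantization config only when
    4-bit or 8-bit loading is requested, otherwise None."""
    if args.load_in_4bit:
        return {"load_in_4bit": True, "bnb_4bit_quant_type": args.bnb_4bit_quant_type}
    if args.load_in_8bit:
        return {"load_in_8bit": True}
    return None

# No quantization flags set: the script skips the device_map branch entirely.
assert get_quantization_config_sketch(ModelConfigSketch()) is None
```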
Usage
Use this pattern when:
- Loading the initial policy model for DPO training
- Loading a model with a specific precision ("bfloat16", "float16", "auto")
- Loading a quantized model for memory-efficient DPO training
- Loading a community model with custom code from the Hub
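The dtype resolution step in the pattern above keeps "auto" and None as-is and otherwise looks the string up as an attribute of the torch module. The branching can be sketched without torch installed by substituting a namespace for the module (fake_torch below is an illustrative stand-in; real code passes torch itself):

```python
from types import SimpleNamespace

# Stand-in for the torch module so the sketch runs without torch installed.
fake_torch = SimpleNamespace(bfloat16="torch.bfloat16", float16="torch.float16")

def resolve_dtype(dtype_str, torch_mod=fake_torch):
    """Mirror the dpo.py pattern: keep "auto"/None unchanged, otherwise
    resolve the name to an attribute of the torch module."""
    return dtype_str if dtype_str in ["auto", None] else getattr(torch_mod, dtype_str)

resolve_dtype("auto")      # passed through unchanged
resolve_dtype("bfloat16")  # resolved via getattr on the module
```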
Code Reference
Source Location
- Repository: TRL
- File: trl/scripts/dpo.py (lines 93-107)
Signature
```python
# Policy model loading pattern from trl/scripts/dpo.py
dtype = (
    model_args.dtype
    if model_args.dtype in ["auto", None]
    else getattr(torch, model_args.dtype)
)
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    **model_kwargs,
)
```
Import
```python
from transformers import AutoModelForCausalLM
from trl import ModelConfig, get_quantization_config, get_kbit_device_map
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str | Yes | Hugging Face model ID or local path to the pretrained model checkpoint |
| model_revision | str or None | No | Specific model revision (branch, tag, or commit hash) to load |
| attn_implementation | str or None | No | Attention backend: None (default), "flash_attention_2", or "sdpa" |
| dtype | str or None | No | Data type for model weights: "auto", "bfloat16", "float16", or "float32" |
| trust_remote_code | bool | No (default: False) | Whether to trust and execute custom model code from the Hub |
| quantization_config | BitsAndBytesConfig or None | No | Quantization configuration for 4-bit or 8-bit loading |
Outputs
| Name | Type | Description |
|---|---|---|
| model | PreTrainedModel | The loaded causal language model, ready to serve as the DPO policy model |
Usage Examples
```python
# Example 1: Basic model loading for DPO
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```

```python
# Example 2: Loading with quantization for memory efficiency
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quantization_config,
    device_map="auto",
    dtype=torch.bfloat16,
)
```

```python
# Example 3: Full DPO script pattern with ModelConfig
import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, get_quantization_config, get_kbit_device_map

model_args = ModelConfig(
    model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
    attn_implementation="flash_attention_2",
)
dtype = (
    model_args.dtype
    if model_args.dtype in ["auto", None]
    else getattr(torch, model_args.dtype)
)
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    **model_kwargs,
)
```