Implementation:Eric mitchell Direct preference optimization AutoModelForCausalLM From Pretrained
| Knowledge Sources | |
|---|---|
| Domains | Model_Initialization, NLP |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Documents how this repository calls the Hugging Face AutoModelForCausalLM.from_pretrained method to load pre-trained causal language models with a specific dtype and device mapping.
Description
This repository uses transformers.AutoModelForCausalLM.from_pretrained to load pre-trained causal language models. The call is configured with:
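- torch_dtype: Converted from a dtype string in the config with getattr(torch, ...) (e.g., "float32" for the policy); the reference model reads its own setting, which may differ
- device_map="balanced": Passed only when the trainer is BasicTrainer, which naively splits the model across the available GPUs; other trainers pass no device_map
- low_cpu_mem_usage=True: Reduces peak CPU memory during model loading
- cache_dir: Resolved by get_local_dir(config.local_dirs) to a local scratch directory where downloaded weights are cached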
After loading, disable_dropout is called immediately to set all dropout probabilities to 0.
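The four settings above can be collected by a small helper before calling from_pretrained. This is a minimal sketch, not code from the repository: the name build_load_kwargs and its flat argument list are hypothetical, and the helper only builds the keyword arguments without downloading or loading any weights.

```python
import torch


def build_load_kwargs(dtype_name: str, trainer: str, cache_dir: str) -> dict:
    """Hypothetical helper: assemble from_pretrained kwargs as described above."""
    # Convert a dtype string such as "float32" into torch.float32;
    # an unknown name raises AttributeError here
    dtype = getattr(torch, dtype_name)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"{dtype_name!r} is not a torch dtype")
    kwargs = {
        "cache_dir": cache_dir,
        "low_cpu_mem_usage": True,
        "torch_dtype": dtype,
    }
    # Only BasicTrainer asks for the model to be split across GPUs
    if trainer == "BasicTrainer":
        kwargs["device_map"] = "balanced"
    return kwargs


policy_kwargs = build_load_kwargs("float32", "BasicTrainer", ".cache")
print(policy_kwargs)
```

Passing the result as from_pretrained(name_or_path, **kwargs) reproduces the call shape used in train.py. Keeping the dtype conversion in one place also means a typo such as "flaot32" fails early with an AttributeError rather than deep inside model loading.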
Usage
Use this when initializing models at the start of training. For SFT, one model is loaded as the policy. For DPO/IPO, two models are loaded: the policy (trainable) and the reference (frozen), both from the same pre-trained checkpoint.
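The SFT versus DPO/IPO split can be expressed as a lookup from loss name to the models that need loading. This is a hypothetical sketch: it assumes the losses are identified by the strings "sft", "dpo", and "ipo", and the function name models_to_load is illustrative only.

```python
def models_to_load(loss_name: str) -> list:
    """Hypothetical helper: which models to build for a given loss."""
    loss_name = loss_name.lower()
    if loss_name == "sft":
        return ["policy"]
    if loss_name in {"dpo", "ipo"}:
        # Both copies start from the same pre-trained checkpoint;
        # only the policy is updated during training
        return ["policy", "reference_model"]
    raise ValueError(f"unknown loss: {loss_name!r}")


for loss in ("sft", "dpo", "ipo"):
    print(loss, models_to_load(loss))
```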
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: train.py
- Lines: 80-92
Signature
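```python
# device_map is only passed when config.trainer == 'BasicTrainer'
model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {}

# Policy model loading (train.py:L83-85)
policy_dtype = getattr(torch, config.model.policy_dtype)  # e.g. "float32" -> torch.float32
policy = transformers.AutoModelForCausalLM.from_pretrained(
    config.model.name_or_path,
    cache_dir=get_local_dir(config.local_dirs),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype,
    **model_kwargs,  # {'device_map': 'balanced'} for BasicTrainer
)
disable_dropout(policy)

# Reference model loading (train.py:L89-92, only for DPO/IPO)
reference_model_dtype = getattr(torch, config.model.reference_dtype)
reference_model = transformers.AutoModelForCausalLM.from_pretrained(
    config.model.name_or_path,
    cache_dir=get_local_dir(config.local_dirs),
    low_cpu_mem_usage=True,
    torch_dtype=reference_model_dtype,
    **model_kwargs,
)
disable_dropout(reference_model)
```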
Import
```python
import torch
import transformers
from utils import disable_dropout, get_local_dir
```
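For experimenting outside the repository, the two helpers imported from utils can be approximated as follows. These are sketches based on the behaviour described on this page (dropout probabilities set to 0; a cache directory picked from a list of prefixes), not copies of utils.py. In particular, returning the first existing prefix and creating the last prefix when none exist are assumptions of this sketch.

```python
import os
from typing import List

import torch


def get_local_dir(prefixes_to_resolve: List[str]) -> str:
    """Sketch: return the first prefix that exists on this machine."""
    for prefix in prefixes_to_resolve:
        if os.path.exists(prefix):
            return prefix
    # Assumption: fall back to creating the last prefix if none exist
    fallback = prefixes_to_resolve[-1]
    os.makedirs(fallback, exist_ok=True)
    return fallback


def disable_dropout(model: torch.nn.Module) -> None:
    """Sketch: set p=0 on every torch.nn.Dropout submodule."""
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
```

Only torch.nn.Dropout instances are modified; dropout that a model applies through a hard-coded float (rather than a Dropout module) would not be affected by this sketch.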
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config.model.name_or_path | str | Yes | HuggingFace model identifier (e.g., "EleutherAI/pythia-2.8b") or local path |
| config.model.policy_dtype | str | Yes | Torch dtype string for policy model (e.g., "float32") |
| config.model.reference_dtype | str | Yes (DPO/IPO) | Torch dtype string for reference model |
| config.local_dirs | List[str] | Yes | Candidate cache directory prefixes; get_local_dir resolves them to the first one that exists |
| config.trainer | str | Yes | Trainer class name; device_map="balanced" is passed only for BasicTrainer |
Outputs
| Name | Type | Description |
|---|---|---|
| policy | transformers.PreTrainedModel (an nn.Module) | Loaded causal LM with dropout disabled |
| reference_model | Optional[transformers.PreTrainedModel] | Second, independently loaded copy of the same checkpoint for DPO/IPO; None for SFT |
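The output contract can be checked right after loading. This is a hypothetical sketch (check_loaded_models is not part of the repository) that assumes the loss is named by one of the strings "sft", "dpo", or "ipo". It verifies that a reference model exists only for DPO/IPO, that no torch.nn.Dropout module still has a nonzero probability, and that the two copies do not share parameter storage.

```python
from typing import Optional

import torch


def check_loaded_models(policy: torch.nn.Module,
                        reference_model: Optional[torch.nn.Module],
                        loss_name: str) -> None:
    """Hypothetical check of the output contract in the table above."""
    needs_reference = loss_name in {"dpo", "ipo"}
    if needs_reference and reference_model is None:
        raise ValueError("DPO/IPO needs a reference model")
    if not needs_reference and reference_model is not None:
        raise ValueError("SFT should not load a reference model")
    # Every Dropout module in either model must have p == 0
    for model in (policy, reference_model):
        if model is None:
            continue
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout) and module.p != 0:
                raise ValueError(f"dropout still active: {module}")
    # Two separately loaded copies must not share parameter tensors
    if reference_model is not None:
        policy_ptrs = {p.data_ptr() for p in policy.parameters()}
        if any(p.data_ptr() in policy_ptrs for p in reference_model.parameters()):
            raise ValueError("policy and reference share weights")
```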
Usage Examples
SFT Model Loading
```python
import torch
import transformers
from utils import disable_dropout, get_local_dir

# Load a single policy model for SFT
policy_dtype = getattr(torch, "float32")
model_kwargs = {'device_map': 'balanced'}  # as used with BasicTrainer
policy = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b",
    cache_dir=get_local_dir(["/scr-ssd", "/scr", ".cache"]),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype,
    **model_kwargs,
)
disable_dropout(policy)
```
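Running the example above needs network access and enough memory for a 2.8B-parameter model. To exercise the same loading path offline, one option (an assumption, not part of the repository) is to save a tiny randomly initialised GPT-2 to a temporary directory and load it back with from_pretrained. cache_dir, low_cpu_mem_usage, and device_map are left out because a local path needs no cache and the latter two rely on the accelerate package; disable_dropout is a local stand-in for the utils helper.

```python
import tempfile

import torch
import transformers


def disable_dropout(model: torch.nn.Module) -> None:
    # Stand-in for utils.disable_dropout: zero every nn.Dropout
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


# A tiny randomly initialised GPT-2 stands in for EleutherAI/pythia-2.8b
tiny_config = transformers.GPT2Config(
    vocab_size=64, n_positions=32, n_embd=16, n_layer=2, n_head=2
)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    transformers.GPT2LMHeadModel(tiny_config).save_pretrained(checkpoint_dir)
    # Same call shape as above, minus cache_dir/low_cpu_mem_usage/device_map
    policy = transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        torch_dtype=torch.float32,
    )
    disable_dropout(policy)

n_dropout = sum(isinstance(m, torch.nn.Dropout) for m in policy.modules())
print(type(policy).__name__, n_dropout)
```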
DPO Dual Model Loading
```python
import torch
import transformers
from utils import disable_dropout, get_local_dir

# Load both policy and reference for DPO training
cache_dir = get_local_dir(["/scr-ssd", "/scr", ".cache"])

policy = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b",
    cache_dir=cache_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    device_map='balanced',
)
disable_dropout(policy)

# Second, independent copy of the same checkpoint; only the policy is trained
reference_model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b",
    cache_dir=cache_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    device_map='balanced',
)
disable_dropout(reference_model)
```
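Because DPO loads two full copies of the same checkpoint, it helps to estimate the weight memory before choosing dtypes. The sketch below is plain arithmetic under stated assumptions: a nominal 2.8 billion parameters (the real Pythia-2.8B count differs slightly), and only parameter storage is counted, ignoring activations, gradients, and optimizer state.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_bytes(n_params: int, dtype_name: str) -> int:
    """Bytes needed for the parameters alone (no activations/optimizer)."""
    return n_params * BYTES_PER_ELEMENT[dtype_name]


N_PARAMS = 2_800_000_000  # nominal size implied by the "2.8b" in the model name

policy_bytes = weight_bytes(N_PARAMS, "float32")
reference_fp32_bytes = weight_bytes(N_PARAMS, "float32")
reference_fp16_bytes = weight_bytes(N_PARAMS, "float16")

print(f"policy fp32:       {policy_bytes / 1e9:.1f} GB")
print(f"policy + ref fp32: {(policy_bytes + reference_fp32_bytes) / 1e9:.1f} GB")
print(f"policy + ref fp16: {(policy_bytes + reference_fp16_bytes) / 1e9:.1f} GB")
```

Under these assumptions a float32 copy needs about 11.2 GB, two float32 copies about 22.4 GB, and a float32 policy with a float16 reference about 16.8 GB, which is one reason the reference dtype is configured separately from the policy dtype.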