Implementation:Eric mitchell Direct preference optimization AutoModelForCausalLM From Pretrained

From Leeroopedia


Knowledge Sources
Domains: Model_Initialization, NLP
Last Updated: 2026-02-08 02:00 GMT

Overview

Documents how this repository calls the HuggingFace AutoModelForCausalLM.from_pretrained method to load causal language models with a specific dtype and device mapping.

Description

This repository uses transformers.AutoModelForCausalLM.from_pretrained to load pre-trained causal language models. The call is configured with:

  • torch_dtype: Resolved from the dtype string in the config (e.g., float32 for the policy; the reference model may use a different dtype)
  • device_map="balanced": Passed when the trainer is BasicTrainer, to naively split the model across the available GPUs
  • low_cpu_mem_usage=True: Reduces peak CPU memory during model loading
  • cache_dir: Set to a local scratch directory for model weight caching
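The options above can be collected into one dictionary before the call. The following is a minimal sketch, not the repository's code: build_load_kwargs and the trainer name "SomeOtherTrainer" are hypothetical, and it assumes that only BasicTrainer requests device_map="balanced". The dtype is kept as a string so the snippet runs without torch installed.

```python
from typing import Any, Dict


def build_load_kwargs(trainer: str, dtype_name: str, cache_dir: str) -> Dict[str, Any]:
    """Collect keyword arguments for from_pretrained (hypothetical helper)."""
    kwargs: Dict[str, Any] = {
        "cache_dir": cache_dir,
        "low_cpu_mem_usage": True,
        # The real call passes a torch.dtype, e.g. getattr(torch, dtype_name).
        "torch_dtype": dtype_name,
    }
    # Assumption: only BasicTrainer splits the model across GPUs at load time.
    if trainer == "BasicTrainer":
        kwargs["device_map"] = "balanced"
    return kwargs


basic = build_load_kwargs("BasicTrainer", "float32", ".cache")
other = build_load_kwargs("SomeOtherTrainer", "float32", ".cache")
```

Keeping the trainer-dependent option in one place means the policy and reference loads can share the same dictionary and differ only in dtype.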

After loading, disable_dropout is called immediately to set all dropout probabilities to 0.
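To illustrate what zeroing dropout means in practice, here is a minimal sketch under the assumption that the helper walks the module tree and sets p to 0 on every torch.nn.Dropout layer. The name disable_dropout_sketch is hypothetical; the helper in utils may be implemented differently.

```python
import torch.nn as nn


def disable_dropout_sketch(model: nn.Module) -> None:
    """Set p = 0 on every nn.Dropout submodule (assumed behavior)."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0


# Toy model standing in for a loaded causal LM.
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.1), nn.Linear(4, 2))
disable_dropout_sketch(toy)
dropout_probs = [m.p for m in toy.modules() if isinstance(m, nn.Dropout)]
```

With p set to 0, a dropout layer passes its input through unchanged even in training mode, so forward passes become deterministic with respect to dropout.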

Usage

Use this when initializing models at the start of training. For SFT, one model is loaded as the policy. For DPO/IPO, two models are loaded: the policy (trainable) and the reference (frozen), both from the same pre-trained checkpoint.
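The one-model versus two-model decision can be sketched as below. This is an illustration under stated assumptions: load_models, LOSSES_WITH_REFERENCE, and the stub loader are hypothetical names, and the loss names "sft", "dpo", and "ipo" are assumed to be how the training mode is identified.

```python
from typing import Callable, Optional, Tuple, TypeVar

M = TypeVar("M")

# Assumption: these loss names are the ones that need a reference model.
LOSSES_WITH_REFERENCE = {"dpo", "ipo"}


def load_models(loss_name: str, load_fn: Callable[[str], M]) -> Tuple[M, Optional[M]]:
    """Return (policy, reference); reference is None unless the loss needs one."""
    policy = load_fn("policy")
    reference = load_fn("reference") if loss_name in LOSSES_WITH_REFERENCE else None
    return policy, reference


# Stub loader so the sketch runs without downloading any weights.
def fake_load(role: str) -> dict:
    return {"role": role}


sft_policy, sft_reference = load_models("sft", fake_load)
dpo_policy, dpo_reference = load_models("dpo", fake_load)
```

In real use, load_fn would wrap the from_pretrained call and disable dropout before returning the model.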

Code Reference

Source Location

Signature

# Policy model loading (train.py:L83-85)
policy = transformers.AutoModelForCausalLM.from_pretrained(
    config.model.name_or_path,
    cache_dir=get_local_dir(config.local_dirs),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype,
    **model_kwargs,  # {'device_map': 'balanced'} for BasicTrainer
)
disable_dropout(policy)

# Reference model loading (train.py:L89-92, only for DPO/IPO)
reference_model = transformers.AutoModelForCausalLM.from_pretrained(
    config.model.name_or_path,
    cache_dir=get_local_dir(config.local_dirs),
    low_cpu_mem_usage=True,
    torch_dtype=reference_model_dtype,
    **model_kwargs,
)
disable_dropout(reference_model)

Import

import transformers
from utils import disable_dropout, get_local_dir

I/O Contract

Inputs

Name | Type | Required | Description
config.model.name_or_path | str | Yes | HuggingFace model identifier (e.g., "EleutherAI/pythia-2.8b") or local path
config.model.policy_dtype | str | Yes | Torch dtype string for policy model (e.g., "float32")
config.model.reference_dtype | str | Yes (DPO) | Torch dtype string for reference model
config.local_dirs | List[str] | Yes | List of directory prefixes for caching
config.trainer | str | Yes | Trainer class name; determines device_map strategy
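As a rough illustration of how a list of directory prefixes might resolve to a single cache location, the sketch below picks the first prefix that exists and otherwise creates the last one. pick_cache_dir is a hypothetical stand-in, and this selection rule is an assumption; it is not the repository's get_local_dir, which may behave differently.

```python
import os
import tempfile
from typing import List


def pick_cache_dir(prefixes: List[str]) -> str:
    """Return the first existing prefix; create and return the last one otherwise."""
    if not prefixes:
        raise ValueError("at least one prefix is required")
    for prefix in prefixes:
        if os.path.exists(prefix):
            return prefix
    # Assumed fallback: create the last candidate so the caller always gets a usable path.
    os.makedirs(prefixes[-1], exist_ok=True)
    return prefixes[-1]


with tempfile.TemporaryDirectory() as existing:
    missing = os.path.join(existing, "does-not-exist")
    chosen = pick_cache_dir([missing, existing])
```

Ordering the prefixes from fastest to slowest storage lets the same config work across machines with different scratch disks.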

Outputs

Name | Type | Description
policy | nn.Module | Loaded causal LM with dropout disabled
reference_model | Optional[nn.Module] | Second copy for DPO/IPO (None for SFT)

Usage Examples

SFT Model Loading

import torch
import transformers
from utils import disable_dropout, get_local_dir

# Load single policy model for SFT
policy_dtype = getattr(torch, "float32")
model_kwargs = {'device_map': 'balanced'}
policy = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b",
    cache_dir=get_local_dir(["/scr-ssd", "/scr", ".cache"]),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype,
    **model_kwargs,
)
disable_dropout(policy)

DPO Dual Model Loading

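import torch
import transformers
from utils import disable_dropout, get_local_dir

# Resolve the weight cache directory once and reuse it for both loads
cache_dir = get_local_dir(["/scr-ssd", "/scr", ".cache"])

# Load both policy and reference for DPO training
policy = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b",
    cache_dir=cache_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    device_map='balanced',
)
disable_dropout(policy)

# The reference starts from the same checkpoint as the policy
reference_model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b",
    cache_dir=cache_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    device_map='balanced',
)
disable_dropout(reference_model)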

Related Pages

Implements Principle

Requires Environment
