
Implementation:Huggingface Alignment handbook Get Model

From Leeroopedia


Knowledge Sources
Domains NLP, Model_Architecture, Deep_Learning
Last Updated 2026-02-07 00:00 GMT

Overview

Concrete tool for loading pretrained causal language models with configurable dtype, attention, and optional quantization, provided by the alignment-handbook library.

Description

The get_model function wraps AutoModelForCausalLM.from_pretrained with alignment-handbook-specific configuration. It resolves the torch_dtype string to a PyTorch dtype, obtains quantization config from TRL's get_quantization_config, and constructs the model keyword arguments including attention implementation and cache settings.
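The dtype-resolution step described above can be sketched as follows. This is an illustrative stand-in, not the library's exact code: the real function resolves the string against the `torch` module (roughly `getattr(torch, torch_dtype)`), while this sketch returns a string so it runs without PyTorch installed.

```python
# Illustrative sketch of torch_dtype string resolution (assumed logic);
# the real implementation maps the string to an attribute of the torch module.
_VALID_DTYPES = {"bfloat16", "float16", "float32"}

def resolve_torch_dtype(torch_dtype):
    # "auto" and None pass through unchanged so that from_pretrained
    # can pick the checkpoint's native dtype.
    if torch_dtype in ("auto", None):
        return torch_dtype
    if torch_dtype not in _VALID_DTYPES:
        raise ValueError(f"Unknown torch_dtype: {torch_dtype!r}")
    # In the real function: return getattr(torch, torch_dtype)
    return f"torch.{torch_dtype}"

print(resolve_torch_dtype("bfloat16"))  # torch.bfloat16
print(resolve_torch_dtype("auto"))     # auto
```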

Usage

Import this function when loading a model in an alignment training script. It is called by every alignment-handbook training script (sft.py, dpo.py, orpo.py) and handles both full-precision and quantized model loading.

Code Reference

Source Location

Signature

def get_model(model_args: ModelConfig, training_args: SFTConfig) -> AutoModelForCausalLM:
    """Get the model.

    Args:
        model_args (ModelConfig): Model configuration from TRL, containing
            model_name_or_path, torch_dtype, attn_implementation,
            trust_remote_code, model_revision, and quantization flags.
        training_args (SFTConfig): Training configuration, used to check
            gradient_checkpointing for use_cache toggle.

    Returns:
        AutoModelForCausalLM: The loaded pretrained model.
    """

Import

from alignment import get_model
from trl import ModelConfig

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| model_args | ModelConfig | Yes | Model configuration from TRL |
| model_args.model_name_or_path | str | Yes | HuggingFace model ID or local path (e.g., "mistralai/Mistral-7B-v0.1") |
| model_args.torch_dtype | str | No | Dtype string: "auto", "bfloat16", "float16", "float32" |
| model_args.attn_implementation | str | No | Attention backend: "flash_attention_2", "sdpa", or None |
| model_args.trust_remote_code | bool | No | Whether to trust remote code for custom architectures |
| model_args.model_revision | str | No | Git revision of the model (branch, tag, or commit hash) |
| training_args | SFTConfig | Yes | Training config (only the gradient_checkpointing field is used) |
| training_args.gradient_checkpointing | bool | No | If True, sets use_cache=False on the model |
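Putting the inputs above together, the keyword-argument assembly can be sketched roughly as below. Plain dicts stand in for the `ModelConfig`/`SFTConfig` objects, and the exact kwarg set is an assumption based on the fields in the table, not the library's verbatim code.

```python
def build_model_kwargs(model_args: dict, training_args: dict) -> dict:
    # Hedged sketch: assemble from_pretrained kwargs from the config fields
    # listed in the table above. Plain dicts stand in for ModelConfig/SFTConfig.
    return {
        "revision": model_args.get("model_revision", "main"),
        "trust_remote_code": model_args.get("trust_remote_code", False),
        "attn_implementation": model_args.get("attn_implementation"),
        "torch_dtype": model_args.get("torch_dtype", "auto"),
        # The KV cache is incompatible with gradient checkpointing's
        # activation recomputation, so use_cache is disabled when it's on.
        "use_cache": not training_args.get("gradient_checkpointing", False),
    }

kwargs = build_model_kwargs(
    {"model_revision": "main", "torch_dtype": "bfloat16",
     "attn_implementation": "flash_attention_2"},
    {"gradient_checkpointing": True},
)
print(kwargs["use_cache"])  # False
```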

Outputs

| Name | Type | Description |
|------|------|-------------|
| return | AutoModelForCausalLM | Loaded pretrained model ready for fine-tuning, with the requested dtype, attention implementation, and optional quantization applied |

Usage Examples

Standard Full-Precision Loading

from alignment import get_model
from trl import ModelConfig

# Model args from YAML config
# model_name_or_path: mistralai/Mistral-7B-v0.1
# torch_dtype: bfloat16
# attn_implementation: flash_attention_2
model = get_model(model_args, training_args)

print(model.dtype)  # torch.bfloat16
print(model.config.use_cache)  # False (if gradient_checkpointing=True)

Loading with Quantization (QLoRA)

from alignment import get_model

# When model_args.load_in_4bit = True:
# - get_quantization_config(model_args) returns BitsAndBytesConfig
# - get_kbit_device_map() returns appropriate device mapping
model = get_model(model_args, training_args)

# Model is now 4-bit quantized on GPU
print(model.is_quantized)  # True
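The branching behind `get_quantization_config` can be sketched like this simplified stand-in. It returns plain dicts where TRL actually returns a `BitsAndBytesConfig`, and the "nf4" default is an assumption for illustration.

```python
def select_quantization_config(load_in_4bit: bool = False,
                               load_in_8bit: bool = False):
    # Simplified stand-in for TRL's get_quantization_config: dicts instead
    # of BitsAndBytesConfig objects; the nf4 quant type is an assumed default.
    if load_in_4bit:
        return {"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"}
    if load_in_8bit:
        return {"load_in_8bit": True}
    return None  # full precision: no quantization config is passed

print(select_quantization_config(load_in_4bit=True))
print(select_quantization_config())  # None
```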

DPO: Loading Model and Reference Model

from alignment import get_model

# In DPO training, the model is loaded twice:
# once as the policy model, once as the frozen reference
model = get_model(model_args, training_args)
ref_model = get_model(model_args, training_args)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
