Implementation: Hugging Face Alignment Handbook get_model
| Knowledge Sources | |
|---|---|
| Domains | NLP, Model_Architecture, Deep_Learning |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for loading pretrained causal language models with configurable dtype, attention, and optional quantization, provided by the alignment-handbook library.
Description
The get_model function wraps AutoModelForCausalLM.from_pretrained with alignment-handbook-specific configuration. It resolves the torch_dtype string to a PyTorch dtype, obtains the quantization config from TRL's get_quantization_config, and constructs the model keyword arguments, including the attention implementation and cache settings.
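The kwargs-assembly logic described above can be sketched in plain Python. The build_model_kwargs function below is a hypothetical stand-in, not the library's code: dtypes are kept as strings here, whereas the real implementation resolves them with getattr(torch, ...) and passes the result to AutoModelForCausalLM.from_pretrained.

```python
def build_model_kwargs(model_args: dict, training_args: dict) -> dict:
    """Hypothetical sketch of how get_model assembles from_pretrained kwargs.

    Dtypes stay as strings in this sketch; the real code resolves them
    via getattr(torch, dtype_name) before passing them on.
    """
    dtype = model_args.get("torch_dtype")
    if dtype not in (None, "auto"):
        dtype = "torch." + dtype  # stand-in for getattr(torch, dtype)
    return {
        "revision": model_args.get("model_revision", "main"),
        "trust_remote_code": model_args.get("trust_remote_code", False),
        "attn_implementation": model_args.get("attn_implementation"),
        "torch_dtype": dtype,
        # The KV cache is incompatible with gradient checkpointing, so it
        # is disabled whenever gradient_checkpointing is enabled.
        "use_cache": not training_args.get("gradient_checkpointing", False),
    }

kwargs = build_model_kwargs(
    {"torch_dtype": "bfloat16", "attn_implementation": "flash_attention_2"},
    {"gradient_checkpointing": True},
)
print(kwargs["torch_dtype"], kwargs["use_cache"])  # torch.bfloat16 False
```

The use_cache toggle mirrors the behavior documented in the I/O contract below: enabling gradient checkpointing forces use_cache=False on the loaded model.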
Usage
Import this function when loading a model for any alignment training script. It is called in every training script (sft.py, dpo.py, orpo.py) and handles both full-precision and quantized model loading.
Code Reference
Source Location
- Repository: alignment-handbook
- File: src/alignment/model_utils.py (lines 37-57)
Signature
def get_model(model_args: ModelConfig, training_args: SFTConfig) -> AutoModelForCausalLM:
"""Get the model.
Args:
model_args (ModelConfig): Model configuration from TRL, containing
model_name_or_path, torch_dtype, attn_implementation,
trust_remote_code, model_revision, and quantization flags.
training_args (SFTConfig): Training configuration, used to check
gradient_checkpointing for use_cache toggle.
Returns:
AutoModelForCausalLM: The loaded pretrained model.
"""
Import
from alignment import get_model
from trl import ModelConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_args | ModelConfig | Yes | Model configuration from TRL |
| model_args.model_name_or_path | str | Yes | HuggingFace model ID or local path (e.g., "mistralai/Mistral-7B-v0.1") |
| model_args.torch_dtype | str | No | Dtype string: "auto", "bfloat16", "float16", "float32" |
| model_args.attn_implementation | str | No | Attention backend: "flash_attention_2", "sdpa", or None |
| model_args.trust_remote_code | bool | No | Whether to trust remote code for custom architectures |
| model_args.model_revision | str | No | Git revision for the model (branch, tag, or commit hash) |
| training_args | SFTConfig | Yes | Training config (only gradient_checkpointing field is used) |
| training_args.gradient_checkpointing | bool | No | If True, sets use_cache=False on the model |
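The input fields above correspond to keys in a handbook recipe YAML, as hinted at in the usage examples further down. A hedged fragment (field names from the table; values are illustrative, not defaults):

```yaml
# Illustrative recipe fragment -- values are examples only
model_name_or_path: mistralai/Mistral-7B-v0.1
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
trust_remote_code: false
# training arguments
gradient_checkpointing: true
```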
Outputs
| Name | Type | Description |
|---|---|---|
| return | AutoModelForCausalLM | Loaded pretrained model ready for fine-tuning, with appropriate dtype, attention implementation, and optional quantization applied |
Usage Examples
Standard Full-Precision Loading
from alignment import get_model
from trl import ModelConfig
# Model args from YAML config
# model_name_or_path: mistralai/Mistral-7B-v0.1
# torch_dtype: bfloat16
# attn_implementation: flash_attention_2
model = get_model(model_args, training_args)
print(model.dtype) # torch.bfloat16
print(model.config.use_cache) # False (if gradient_checkpointing=True)
Loading with Quantization (QLoRA)
from alignment import get_model
# When model_args.load_in_4bit = True:
# - get_quantization_config(model_args) returns BitsAndBytesConfig
# - get_kbit_device_map() returns appropriate device mapping
model = get_model(model_args, training_args)
# Model is now 4-bit quantized on GPU
print(model.is_quantized) # True
DPO: Loading Model and Reference Model
from alignment import get_model
# In DPO training, the model is loaded twice:
# once as the policy model, once as the frozen reference
model = get_model(model_args, training_args)
ref_model = get_model(model_args, training_args)