Implementation:Princeton nlp SimPO Get Tokenizer and Model Config
| Knowledge Sources | |
|---|---|
| Domains | NLP, Model_Loading, Quantization |
| Last Updated | 2026-02-08 04:30 GMT |
Overview
Concrete tools for loading the tokenizer and building the quantization and LoRA configurations used in SimPO training, provided by the SimPO alignment package.
Description
Three utility functions work together to prepare the model-loading context:

- get_tokenizer loads and configures the tokenizer (pad token, truncation side, chat template).
- get_quantization_config builds a BitsAndBytesConfig for 4-bit or 8-bit quantization.
- get_peft_config creates a LoraConfig for parameter-efficient training.

A fourth helper, get_kbit_device_map, provides the device map needed when loading a quantized model.

The quantization config and device map go into a model_kwargs dictionary, together with the dtype and attention settings. That dictionary reaches SimPOTrainer through training_args.model_init_kwargs. The tokenizer and the LoRA config are passed to SimPOTrainer directly.
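The truncation side decides which end of an over-long sequence is discarded. The sketch below shows what each setting keeps. It uses a hypothetical `truncate` helper on plain integer lists rather than a real tokenizer. In a conversation, left truncation preserves the final turn.

```python
from typing import List


def truncate(token_ids: List[int], max_length: int, side: str = "right") -> List[int]:
    """Keep at most max_length tokens, dropping the excess from one side.

    Illustrates the meaning of a tokenizer's truncation_side ("left" or "right");
    this is not the tokenizer's own implementation.
    """
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    if max_length <= 0:
        return []
    if len(token_ids) <= max_length:
        return list(token_ids)
    if side == "left":
        # Drop tokens from the start; the end of the sequence survives.
        return list(token_ids[-max_length:])
    # Drop tokens from the end; the start of the sequence survives.
    return list(token_ids[:max_length])


# Toy conversation: ids 0-5 stand for an early turn, 6-9 for the final turn.
conversation = list(range(10))
print(truncate(conversation, 4, side="left"))   # [6, 7, 8, 9]
print(truncate(conversation, 4, side="right"))  # [0, 1, 2, 3]
```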
Usage
Call these functions after configuration parsing and before instantiating SimPOTrainer.

- The tokenizer is used right away to apply the chat template to the dataset.
- The quantization config reaches the trainer through model_kwargs.
- The PEFT config is passed as the trainer's peft_config argument.
Code Reference
Source Location
- Repository: SimPO
- File: alignment/model_utils.py (Lines 38-111)
Signature
```python
def get_tokenizer(
    model_args: ModelArguments,
    data_args: DataArguments,
    auto_set_chat_template: bool = True,
) -> PreTrainedTokenizer:
    """
    Get the tokenizer for the model.

    Handles:
    - Loading from tokenizer_name_or_path if set, otherwise model_name_or_path
    - Setting pad_token_id to eos_token_id if not defined
    - Configuring truncation_side from data_args
    - Capping model_max_length at 2048 if unreasonably large (e.g. no limit configured)
    - Setting the chat template from data_args, or DEFAULT_CHAT_TEMPLATE
      when auto_set_chat_template is True and no template is defined
    """
```
```python
def get_quantization_config(
    model_args: ModelArguments,
) -> BitsAndBytesConfig | None:
    """
    Build BitsAndBytesConfig from model_args.
    Returns None if neither load_in_4bit nor load_in_8bit is set.
    """
```
```python
def get_peft_config(
    model_args: ModelArguments,
) -> PeftConfig | None:
    """
    Build LoraConfig from model_args.
    Returns None if use_peft is False.
    """
```
```python
def get_kbit_device_map() -> Dict[str, int] | None:
    """
    Returns {"": local_process_index} for quantized model loading.
    Returns None if CUDA is not available.
    """
```
Import
```python
from alignment import (
    get_tokenizer,
    get_quantization_config,
    get_peft_config,
    get_kbit_device_map,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_args | ModelArguments | Yes | Model selection, quantization flags, LoRA parameters |
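| data_args | DataArguments | Yes (get_tokenizer only) | Truncation side, chat template settings |
| auto_set_chat_template | bool | No (default True) | get_tokenizer only; fall back to DEFAULT_CHAT_TEMPLATE when no template is set |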
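The sketch below shows, in plain Python, one plausible way the quantization and LoRA flags in model_args could map onto the two configs. It is not the package's code.

- `ToyModelArgs`, `quantization_kwargs` and `lora_kwargs` are hypothetical names.
- The field names and defaults are assumptions, including the float16 compute-dtype fallback.
- The functions return keyword dictionaries, so the sketch runs without transformers, peft or bitsandbytes installed.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ToyModelArgs:
    """Hypothetical stand-in for the quantization and LoRA fields of ModelArguments.

    Field names and defaults are assumptions for illustration.
    """
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    torch_dtype: Optional[str] = None
    bnb_4bit_quant_type: str = "nf4"
    use_bnb_nested_quant: bool = False
    use_peft: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: Optional[List[str]] = None


def quantization_kwargs(args: ToyModelArgs) -> Optional[Dict[str, Any]]:
    """Keyword arguments for a BitsAndBytesConfig, or None if no quantization is requested."""
    if args.load_in_4bit:
        # Dtype kept as a name; real code would resolve it to a torch dtype.
        # The float16 fallback is an assumption of this sketch.
        compute_dtype = "float16" if args.torch_dtype in ("auto", None) else args.torch_dtype
        return {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": compute_dtype,
            "bnb_4bit_quant_type": args.bnb_4bit_quant_type,
            "bnb_4bit_use_double_quant": args.use_bnb_nested_quant,
        }
    if args.load_in_8bit:
        return {"load_in_8bit": True}
    return None


def lora_kwargs(args: ToyModelArgs) -> Optional[Dict[str, Any]]:
    """Keyword arguments for a LoraConfig, or None when PEFT is disabled."""
    if not args.use_peft:
        return None
    return {
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": args.lora_target_modules,
    }


args = ToyModelArgs(load_in_4bit=True, torch_dtype="bfloat16", use_peft=True)
print(quantization_kwargs(args))
print(lora_kwargs(args))
```

With those libraries installed, a non-None dictionary could be unpacked into the real classes, for example `BitsAndBytesConfig(**kw)` or `LoraConfig(**kw)`. In this sketch, 4-bit takes precedence if both quantization flags are set.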
Outputs
| Name | Type | Description |
|---|---|---|
| tokenizer | PreTrainedTokenizer | Configured tokenizer with pad token, truncation, and chat template set |
| quantization_config | BitsAndBytesConfig or None | Quantization config for 4-bit/8-bit loading, or None |
| peft_config | PeftConfig or None | LoRA config for parameter-efficient training, or None |
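| device_map | Dict[str, int] or None | From get_kbit_device_map: {"": local_process_index}, or None if CUDA is unavailable |
| model_kwargs | dict | Keyword arguments for AutoModelForCausalLM.from_pretrained, assembled by the caller from the outputs above |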
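The tokenizer adjustments listed in the get_tokenizer docstring can be illustrated on a stand-in object. This is a sketch, not the package's code.

- `ToyTokenizer`, `apply_tokenizer_defaults` and the placeholder `DEFAULT_CHAT_TEMPLATE` string are hypothetical.
- The 100,000 threshold for "unreasonably large" is an assumption of this sketch.
- The `int(1e30)` default mirrors the sentinel transformers uses when no maximum length is configured.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical placeholder; the package defines its own DEFAULT_CHAT_TEMPLATE.
DEFAULT_CHAT_TEMPLATE = "<default chat template>"


@dataclass
class ToyTokenizer:
    """Hypothetical stand-in holding only the attributes get_tokenizer adjusts."""
    eos_token_id: int
    pad_token_id: Optional[int] = None
    truncation_side: str = "right"
    model_max_length: int = int(1e30)  # transformers' sentinel when no limit is configured
    chat_template: Optional[str] = None


def apply_tokenizer_defaults(
    tok: ToyTokenizer,
    truncation_side: Optional[str] = None,
    chat_template: Optional[str] = None,
    auto_set_chat_template: bool = True,
    max_length_threshold: int = 100_000,  # assumed cutoff for "unreasonably large"
) -> ToyTokenizer:
    """Apply the get_tokenizer adjustments in place and return the tokenizer."""
    # Reuse EOS as the padding token when none is defined.
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    # Apply the truncation side requested by the data arguments, if any.
    if truncation_side is not None:
        tok.truncation_side = truncation_side
    # Replace a missing or huge maximum length with 2048.
    if tok.model_max_length > max_length_threshold:
        tok.model_max_length = 2048
    # An explicit template wins; otherwise fall back to the default when none is set.
    if chat_template is not None:
        tok.chat_template = chat_template
    elif auto_set_chat_template and tok.chat_template is None:
        tok.chat_template = DEFAULT_CHAT_TEMPLATE
    return tok


tok = apply_tokenizer_defaults(ToyTokenizer(eos_token_id=2), truncation_side="left")
print(tok.pad_token_id, tok.truncation_side, tok.model_max_length)  # 2 left 2048
```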
Usage Examples
Standard Initialization
```python
import torch

from alignment import (
    get_tokenizer,
    get_quantization_config,
    get_kbit_device_map,
)

# model_args, data_args and training_args come from the earlier
# configuration-parsing step (ModelArguments, DataArguments, SimPO config).

# Truncate from the left so the final turn (the responses being scored) is kept
data_args.truncation_side = "left"

# Load and configure the tokenizer
tokenizer = get_tokenizer(model_args, data_args)

# Build the quantization config (None unless 4-bit or 8-bit loading is requested)
quantization_config = get_quantization_config(model_args)

# Resolve the dtype name (e.g. "bfloat16" -> torch.bfloat16); "auto" and None pass through
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)

# Assemble model_kwargs; SimPOTrainer loads the model lazily from these
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=torch_dtype,
    # The KV cache is not needed in training and is incompatible with gradient checkpointing
    use_cache=False if training_args.gradient_checkpointing else True,
    # Quantized weights are placed on this process's GPU at load time
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
    attn_implementation=model_args.attn_implementation,
)

# Hand the kwargs to SimPOTrainer, which uses them when loading the model
training_args.model_init_kwargs = model_kwargs
```
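The conditional entries in model_kwargs can be checked in isolation. This sketch drops torch and accelerate and keeps only the conditional keys.

- `kbit_device_map` and `build_model_kwargs` are hypothetical helpers.
- They take CUDA availability and the local process index as plain arguments instead of querying the runtime.

```python
from typing import Any, Dict, Optional


def kbit_device_map(cuda_available: bool, local_process_index: int) -> Optional[Dict[str, int]]:
    """Map the whole model ("" = root module) to this process's GPU, or None without CUDA."""
    return {"": local_process_index} if cuda_available else None


def build_model_kwargs(
    torch_dtype: Optional[str],
    gradient_checkpointing: bool,
    quantization_config: Optional[Any],
    cuda_available: bool,
    local_process_index: int,
) -> Dict[str, Any]:
    """Assemble the load-time kwargs; revision, trust_remote_code, etc. omitted for brevity."""
    return dict(
        torch_dtype=torch_dtype,
        # Caching is switched off whenever gradient checkpointing is on.
        use_cache=not gradient_checkpointing,
        # A device map is only supplied for quantized loading.
        device_map=(
            kbit_device_map(cuda_available, local_process_index)
            if quantization_config is not None
            else None
        ),
        quantization_config=quantization_config,
    )


# Rank 1 of a multi-GPU run, loading in 4-bit with gradient checkpointing.
kwargs = build_model_kwargs("bfloat16", True, {"load_in_4bit": True}, True, 1)
print(kwargs["use_cache"], kwargs["device_map"])  # False {'': 1}
```

Pinning the root module ("") to the local index gives each data-parallel rank its own full quantized copy on its own GPU.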
LoRA Configuration
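```python
from alignment import get_peft_config

# model_args, training_args, tokenizer and raw_datasets come from the earlier steps;
# SimPOTrainer is defined in the SimPO repository (import it from there).

# With use_peft=True in ModelArguments (otherwise get_peft_config returns None):
peft_config = get_peft_config(model_args)
# Equivalent to the following, with the default ModelArguments values in comments:
# LoraConfig(
#     r=model_args.lora_r,                   # 16
#     lora_alpha=model_args.lora_alpha,      # 32
#     lora_dropout=model_args.lora_dropout,  # 0.05
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=model_args.lora_target_modules,
# )

# Pass the model name, tokenizer and LoRA config to SimPOTrainer
trainer = SimPOTrainer(
    model=model_args.model_name_or_path,
    args=training_args,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    peft_config=peft_config,
)
```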
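A short worked example gives a sense of scale for r.

- LoRA adds two low-rank matrices per adapted linear layer: B (d_out × r) and A (r × d_in).
- The added parameter count per layer is therefore r·(d_out + d_in).
- The example assumes a hypothetical 4096 × 4096 projection. The hidden size is illustrative and not taken from any particular SimPO model.

```python
def lora_added_params(d_out: int, d_in: int, r: int) -> int:
    """Parameters LoRA adds to one d_out x d_in linear layer: B (d_out x r) plus A (r x d_in)."""
    return r * (d_out + d_in)


full = 4096 * 4096  # parameters in the frozen base weight
added = lora_added_params(4096, 4096, r=16)
print(added, full, f"{added / full:.2%}")  # 131072 16777216 0.78%
```

With lora_alpha=32 and r=16, the adapter's update is scaled by lora_alpha / r = 2 in standard LoRA.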