

Implementation:Princeton nlp SimPO Get Tokenizer and Model Config



Knowledge Sources
Domains NLP, Model_Loading, Quantization
Last Updated 2026-02-08 04:30 GMT

Overview

Concrete tool for loading tokenizers and constructing quantization/LoRA configurations for SimPO training, provided by the SimPO alignment package.

Description

Three utility functions work together to prepare the model-loading context: get_tokenizer loads and configures the tokenizer (pad token, truncation side, chat template), get_quantization_config builds a BitsAndBytesConfig for 4-bit or 8-bit quantization, and get_peft_config creates a LoraConfig for parameter-efficient training. A fourth helper, get_kbit_device_map, provides the device map needed for quantized model loading. The quantization config and device map are assembled, together with dtype, revision, and attention settings, into a model_kwargs dictionary that reaches SimPOTrainer via training_args.model_init_kwargs. The tokenizer and the LoRA config are passed to SimPOTrainer directly, as its tokenizer and peft_config arguments.
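The way the device map and quantization config feed model_kwargs can be sketched in plain Python. This is a minimal sketch, not the package's code: kbit_device_map and build_model_kwargs are hypothetical stand-ins, and CUDA availability and the local process index are passed in as arguments instead of being read from torch and Accelerate.

```python
from typing import Any, Dict, Optional


def kbit_device_map(cuda_available: bool, local_process_index: int) -> Optional[Dict[str, int]]:
    """Hypothetical stand-in for get_kbit_device_map.

    The key "" refers to the whole model, so every module is placed on
    this process's GPU. Without CUDA there is no device map.
    """
    return {"": local_process_index} if cuda_available else None


def build_model_kwargs(
    quantization_config: Optional[Any],
    gradient_checkpointing: bool,
    cuda_available: bool,
    local_process_index: int,
) -> Dict[str, Any]:
    """Hypothetical helper: only quantized loads receive an explicit device map."""
    return {
        "use_cache": not gradient_checkpointing,
        "quantization_config": quantization_config,
        "device_map": (
            kbit_device_map(cuda_available, local_process_index)
            if quantization_config is not None
            else None
        ),
    }


# A quantized load on process 1 is pinned to GPU 1; a full-precision load is not pinned.
print(build_model_kwargs({"load_in_4bit": True}, True, True, 1))
print(build_model_kwargs(None, False, True, 0))
```

Pinning each process to its own GPU means every data-parallel worker loads a complete quantized copy of the model, rather than one copy being sharded across devices.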

Usage

Call these functions after configuration parsing, before instantiating SimPOTrainer. The tokenizer is used immediately for chat template application; the quantization and PEFT configs are passed to the trainer.

Code Reference

Source Location

  • Repository: SimPO
  • File: alignment/model_utils.py (Lines 38-111)

Signature

def get_tokenizer(
    model_args: ModelArguments,
    data_args: DataArguments,
    auto_set_chat_template: bool = True,
) -> PreTrainedTokenizer:
    """
    Get the tokenizer for the model.

    Handles:
    - Loading from model_name_or_path or tokenizer_name_or_path
    - Setting pad_token_id to eos_token_id if not defined
    - Configuring truncation_side from data_args
    - Capping model_max_length at 2048 if unreasonably large
    - Setting chat template from data_args or DEFAULT_CHAT_TEMPLATE
    """

def get_quantization_config(
    model_args: ModelArguments,
) -> BitsAndBytesConfig | None:
    """
    Build BitsAndBytesConfig from model_args.
    Returns None if neither load_in_4bit nor load_in_8bit is set.
    """

def get_peft_config(
    model_args: ModelArguments,
) -> PeftConfig | None:
    """
    Build LoraConfig from model_args.
    Returns None if use_peft is False.
    """

def get_kbit_device_map() -> Dict[str, int] | None:
    """
    Returns {"": local_process_index} for quantized model loading.
    Returns None if CUDA is not available.
    """
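The configuration steps listed in the get_tokenizer docstring can be illustrated with a stub tokenizer. This is a sketch under stated assumptions: StubTokenizer and configure_tokenizer are hypothetical, the DEFAULT_CHAT_TEMPLATE value here is a placeholder, and the 100,000 threshold for an "unreasonably large" model_max_length is assumed rather than taken from the SimPO source.

```python
from dataclasses import dataclass
from typing import Optional

# Placeholder only; the real DEFAULT_CHAT_TEMPLATE is a Jinja template string.
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}...{% endfor %}"
MAX_LENGTH_CAP = 2048
# Assumed threshold for an "unreasonably large" model_max_length.
UNREASONABLE_LENGTH = 100_000


@dataclass
class StubTokenizer:
    """Hypothetical stand-in exposing only the attributes get_tokenizer touches."""
    pad_token_id: Optional[int]
    eos_token_id: int
    model_max_length: int
    truncation_side: str = "right"
    chat_template: Optional[str] = None


def configure_tokenizer(
    tok: StubTokenizer,
    truncation_side: Optional[str],
    chat_template: Optional[str],
    auto_set_chat_template: bool = True,
) -> StubTokenizer:
    # Reuse EOS for padding only when no pad token exists (0 is a valid id).
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    # Apply the truncation side requested in data_args, if any.
    if truncation_side is not None:
        tok.truncation_side = truncation_side
    # Replace a sentinel "unlimited" length with a usable cap.
    if tok.model_max_length > UNREASONABLE_LENGTH:
        tok.model_max_length = MAX_LENGTH_CAP
    # An explicit template wins; the default is used only if none is set.
    if chat_template is not None:
        tok.chat_template = chat_template
    elif auto_set_chat_template and tok.chat_template is None:
        tok.chat_template = DEFAULT_CHAT_TEMPLATE
    return tok


tok = configure_tokenizer(StubTokenizer(None, 2, int(1e30)), "left", None)
print(tok)
```

Hugging Face tokenizers without a configured limit report model_max_length as int(1e30), which is why a cap like this is needed before truncation can have any effect.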

Import

from alignment import (
    get_tokenizer,
    get_quantization_config,
    get_peft_config,
    get_kbit_device_map,
)

I/O Contract

Inputs

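Name | Type | Required | Description
model_args | ModelArguments | Yes | Model selection, quantization flags, LoRA parameters
data_args | DataArguments | Yes (get_tokenizer only) | Truncation side, chat template settings
auto_set_chat_template | bool | No (default True) | Whether get_tokenizer may fall back to DEFAULT_CHAT_TEMPLATE when no template is configured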

Outputs

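Name | Type | Description
tokenizer | PreTrainedTokenizer | Configured tokenizer with pad token, truncation side, and chat template set
quantization_config | BitsAndBytesConfig or None | Config for 4-bit/8-bit loading, or None when neither load_in_4bit nor load_in_8bit is set
peft_config | PeftConfig or None | LoRA config for parameter-efficient training, or None when use_peft is False
device_map | Dict[str, int] or None | {"": local_process_index} from get_kbit_device_map, or None when CUDA is unavailable
model_kwargs | dict | Keyword arguments for AutoModelForCausalLM.from_pretrained, assembled by the caller from the outputs above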
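The "config or None" contract for quantization_config can be sketched as a function that returns the keyword arguments one would pass to BitsAndBytesConfig. This is a hypothetical sketch: quantization_kwargs is not part of the alignment package, dtypes are kept as strings to avoid a torch dependency, and two behaviors are assumptions rather than documented facts: 4-bit takes precedence when both flags are set, and the 4-bit compute dtype falls back to float16 when torch_dtype is "auto" or unset.

```python
from typing import Any, Dict, Optional


def quantization_kwargs(
    load_in_4bit: bool,
    load_in_8bit: bool,
    torch_dtype: Optional[str] = None,
    bnb_4bit_quant_type: str = "nf4",
    use_double_quant: bool = False,
) -> Optional[Dict[str, Any]]:
    """Hypothetical: kwargs for BitsAndBytesConfig, or None if not quantizing."""
    if load_in_4bit:
        # Assumption: compute in the requested dtype, else float16.
        compute_dtype = "float16" if torch_dtype in ("auto", None) else torch_dtype
        return {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": compute_dtype,
            "bnb_4bit_quant_type": bnb_4bit_quant_type,
            "bnb_4bit_use_double_quant": use_double_quant,
        }
    if load_in_8bit:
        return {"load_in_8bit": True}
    # Neither flag set: no quantization config, so no device map is needed either.
    return None


print(quantization_kwargs(True, False, torch_dtype="bfloat16"))
print(quantization_kwargs(False, False))
```

With real libraries, the string dtype would be resolved with getattr(torch, ...) and the dictionary unpacked into transformers.BitsAndBytesConfig(**kwargs).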

Usage Examples

Standard Initialization

from alignment import (
    get_tokenizer, get_quantization_config, get_peft_config,
    get_kbit_device_map, ModelArguments, DataArguments,
)
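import torch

# model_args, data_args, and training_args are assumed to be already parsed
# (ModelArguments, DataArguments, and the SimPO training config).

# Force left truncation for preference tasks
data_args.truncation_side = "left"

# Load tokenizer (pad token, truncation side, and chat template configured)
tokenizer = get_tokenizer(model_args, data_args)

# Build quantization config (None unless load_in_4bit or load_in_8bit is set)
quantization_config = get_quantization_config(model_args)

# Assemble model_kwargs for lazy loading in SimPOTrainer.
# Resolve a dtype string such as "bfloat16" to torch.bfloat16; keep "auto"/None as-is.
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=torch_dtype,
    # The KV cache is incompatible with gradient checkpointing during training.
    use_cache=not training_args.gradient_checkpointing,
    # Only quantized loads pin the model to this process's GPU.
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
    attn_implementation=model_args.attn_implementation,
)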

# Pass to training_args for SimPOTrainer
training_args.model_init_kwargs = model_kwargs
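Why left truncation suits preference data can be shown on a toy token list. This sketch uses a hypothetical truncate helper, not the tokenizer's implementation, and assumes the usual layout where the final user turn sits immediately before the chosen and rejected responses; dropping tokens from the left keeps that turn, while dropping from the right discards it.

```python
from typing import List


def truncate(tokens: List[str], max_length: int, side: str) -> List[str]:
    """Hypothetical helper: keep at most max_length tokens, dropping from `side`."""
    if len(tokens) <= max_length:
        return list(tokens)
    if side == "left":
        # Drop the oldest tokens; the end of the prompt survives.
        return tokens[len(tokens) - max_length:]
    if side == "right":
        # Drop the newest tokens; the end of the prompt is lost.
        return tokens[:max_length]
    raise ValueError(f"unknown truncation side: {side!r}")


prompt = ["sys", "t1", "t2", "user", "final", "q"]
print(truncate(prompt, 3, "left"))   # keeps the final user turn
print(truncate(prompt, 3, "right"))  # keeps only the oldest context
```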

LoRA Configuration

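# When use_peft=True in ModelArguments (otherwise get_peft_config returns None):
peft_config = get_peft_config(model_args)
# Returns a LoraConfig built from model_args, e.g. with the default LoRA settings:
# LoraConfig(
#     r=model_args.lora_r,                     # e.g. 16
#     lora_alpha=model_args.lora_alpha,        # e.g. 32
#     lora_dropout=model_args.lora_dropout,    # e.g. 0.05
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=model_args.lora_target_modules,
# )

# Pass to SimPOTrainer (from the SimPO training scripts, not the alignment package).
# The model is given as a name/path; the trainer loads it using
# training_args.model_init_kwargs. raw_datasets is the preprocessed
# DatasetDict with "train" and "test" splits.
trainer = SimPOTrainer(
    model=model_args.model_name_or_path,
    args=training_args,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    peft_config=peft_config,
)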
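The use_peft gate in get_peft_config can be sketched without importing peft. This is a hypothetical illustration: LoraArgs is an invented subset of ModelArguments whose defaults copy the values shown in the LoRA example, and lora_kwargs returns the dictionary one would unpack into peft.LoraConfig rather than building the config itself.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class LoraArgs:
    """Hypothetical subset of ModelArguments relevant to LoRA."""
    use_peft: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: Optional[List[str]] = None


def lora_kwargs(args: LoraArgs) -> Optional[Dict[str, Any]]:
    """Hypothetical: kwargs for peft.LoraConfig, or None for full fine-tuning."""
    if not args.use_peft:
        # SimPOTrainer then receives peft_config=None and trains all weights.
        return None
    return {
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": args.lora_target_modules,
    }


kw = lora_kwargs(LoraArgs(use_peft=True, lora_target_modules=["q_proj", "v_proj"]))
print(kw)
print("LoRA scaling (alpha / r):", kw["lora_alpha"] / kw["r"])
```

LoRA scales the low-rank update by lora_alpha / r, so with r=16 and lora_alpha=32 the adapter output is multiplied by 2.0.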

Related Pages

Implements Principle
