
Implementation:Huggingface Trl AutoModelForCausalLM From Pretrained SFT



Knowledge Sources
Domains: NLP, Training
Last Updated: 2026-02-06 17:00 GMT

Overview

A concrete pattern for loading a pretrained causal language model for SFT, with optional quantization and attention-kernel selection, as implemented in the TRL library on top of Hugging Face Transformers.

Description

The SFT script assembles a model_kwargs dictionary from the parsed ModelConfig and passes it to AutoModelForCausalLM.from_pretrained(). If the model architecture is a vision-language model (detected via MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES), the script uses AutoModelForImageTextToText instead. The get_quantization_config() utility converts boolean flags (load_in_4bit, load_in_8bit) into a BitsAndBytesConfig object, and get_kbit_device_map() provides the correct device mapping for quantized models.

Usage

Use this pattern when you need to load a model for SFT training, especially when combining quantization with LoRA adapters (QLoRA workflow).
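
In the QLoRA case, the quantized base model is typically paired with a LoRA adapter after loading. A minimal sketch using the peft package and TRL's SFTTrainer; the LoraConfig values (r, lora_alpha, target_modules) are illustrative choices, not TRL defaults:

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# 4-bit NF4 base model (mirrors what get_quantization_config builds from the flags).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    quantization_config=bnb_config,
    device_map={"": 0},  # single-GPU mapping; get_kbit_device_map generalizes this
)

# Illustrative LoRA hyperparameters; tune for your model and task.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="./output"),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
    peft_config=peft_config,  # SFTTrainer attaches the adapter to the quantized base
)
trainer.train()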

Code Reference

Source Location

  • Repository: TRL
  • File: trl/scripts/sft.py (lines 91-116, model loading)
  • File: trl/trainer/utils.py (lines 283-299, get_quantization_config; lines 302-306, get_kbit_device_map)

Signature

# Pattern assembled in trl/scripts/sft.py:main()
def main(script_args, training_args, model_args, dataset_args):
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        dtype=model_args.dtype,
    )
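    # Turn the load_in_4bit / load_in_8bit flags into a BitsAndBytesConfig (or None).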
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    # Route vision-language architectures to AutoModelForImageTextToText.
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    if config.architectures and any(
        arch in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
        for arch in config.architectures
    ):
        model = AutoModelForImageTextToText.from_pretrained(
            model_args.model_name_or_path, **model_kwargs
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path, **model_kwargs
        )


# Helper functions
def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | None:
    ...

def get_kbit_device_map() -> dict[str, int] | None:
    ...
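
The helper bodies are elided above. As a rough sketch of the behavior the Description attributes to them (not the verbatim TRL source):

def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | None:
    # Sketch: turn the boolean flags into a BitsAndBytesConfig, or None if unquantized.
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=model_args.dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def get_kbit_device_map() -> dict[str, int] | None:
    # Sketch: pin the whole model to the current process's accelerator for k-bit loading.
    import torch
    from accelerate import PartialState

    return {"": PartialState().local_process_index} if torch.cuda.is_available() else None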

Import

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import ModelConfig, get_kbit_device_map, get_quantization_config

I/O Contract

Inputs

  • model_args.model_name_or_path (str, required): Hugging Face Hub model ID or local path to a model directory
  • model_args.dtype (str, optional): precision, one of "auto", "bfloat16", "float16", or "float32"; default "float32"
  • model_args.model_revision (str, optional): Git revision (branch, tag, or commit hash); default "main"
  • model_args.trust_remote_code (bool, optional): allow execution of custom modeling code from the Hub; default False
  • model_args.attn_implementation (str | None, optional): attention kernel, e.g. "flash_attention_2" or "flash_attention_3"; default None (framework default)
  • model_args.load_in_4bit (bool, optional): load base weights in 4-bit NF4 quantization; default False
  • model_args.load_in_8bit (bool, optional): load base weights in 8-bit quantization; default False
  • model_args.bnb_4bit_quant_type (str, optional): quantization type, "nf4" or "fp4"; default "nf4"
  • model_args.use_bnb_nested_quant (bool, optional): double quantization of the quantization constants; default False

Outputs

  • model (PreTrainedModel): loaded causal language model, ready for fine-tuning or PEFT adapter wrapping
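
A quick post-load sanity check can confirm that the requested precision and quantization took effect; is_loaded_in_4bit is the attribute Transformers sets on bitsandbytes-quantized models (attribute names may vary across versions):

print(model.dtype)                                 # e.g. torch.bfloat16
print(getattr(model, "is_loaded_in_4bit", False))  # True after a 4-bit load
print(f"{model.num_parameters():,} parameters")    # base model size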

Usage Examples

Basic Usage (Full Precision)

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    dtype="bfloat16",
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)
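
SFT also needs the checkpoint's tokenizer; loading it from the same model ID keeps the vocabulary and chat template consistent (a standard pairing, not specific to this script):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback for causal LMs with no pad token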

QLoRA Usage (4-bit Quantization)

from transformers import AutoModelForCausalLM
from trl import ModelConfig, get_quantization_config, get_kbit_device_map

model_args = ModelConfig(
    model_name_or_path="meta-llama/Llama-3.1-8B",
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    use_bnb_nested_quant=True,
    dtype="bfloat16",
)

quantization_config = get_quantization_config(model_args)
device_map = get_kbit_device_map()

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    quantization_config=quantization_config,
    device_map=device_map,
    dtype=model_args.dtype,
)

Using the TRL SFT Script (CLI)

python trl/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --load_in_4bit \
    --use_peft \
    --output_dir ./output
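
TRL also installs a trl console entry point; assuming it is available on your PATH, the same run can be written as:

trl sft \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --load_in_4bit \
    --use_peft \
    --output_dir ./output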
