Implementation: Hugging Face TRL AutoModelForCausalLM.from_pretrained (SFT)
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
A concrete pattern for loading a pretrained causal language model for SFT, with optional quantization and attention kernel selection, as implemented by the TRL library on top of Hugging Face Transformers.
Description
The SFT script assembles a model_kwargs dictionary from the parsed ModelConfig and passes it to AutoModelForCausalLM.from_pretrained(). If the model architecture is a vision-language model (detected via MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES), the script uses AutoModelForImageTextToText instead. The get_quantization_config() utility converts boolean flags (load_in_4bit, load_in_8bit) into a BitsAndBytesConfig object, and get_kbit_device_map() provides the correct device mapping for quantized models.
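To make the flag-to-config mapping concrete, here is a minimal sketch of how the boolean quantization flags could translate into a `BitsAndBytesConfig` and a device map. This approximates the behavior described above; it is not TRL's exact implementation, and the helper names below are illustrative.

```python
import torch
from transformers import BitsAndBytesConfig


def build_bnb_config(model_args) -> BitsAndBytesConfig | None:
    # Sketch only: approximates TRL's get_quantization_config().
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            # Compute in the training dtype (falls back to float32 for "auto").
            bnb_4bit_compute_dtype=getattr(torch, model_args.dtype, torch.float32),
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,  # "nf4" or "fp4"
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None  # no quantization requested


def kbit_device_map() -> dict[str, int] | None:
    # Sketch: pin the entire quantized model to the local GPU so that
    # bitsandbytes layers are not split across devices.
    return {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
```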
Usage
Use this pattern when you need to load a model for SFT training, especially when combining quantization with LoRA adapters (QLoRA workflow).
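For illustration, a hedged sketch of the QLoRA combination: load the base model in 4-bit, then wrap it with LoRA adapters via peft. The rank, alpha, and target_modules values below are placeholder assumptions, not values taken from the TRL script.

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumption: any decoder-only checkpoint works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map={"": 0},
)

model = prepare_model_for_kbit_training(model)  # cast norms, prep for k-bit training
lora_config = LoraConfig(
    r=16,                                  # placeholder rank
    lora_alpha=32,                         # placeholder scaling
    target_modules=["q_proj", "v_proj"],   # assumed attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)  # only adapter weights are trainable
```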
Code Reference
Source Location
- Repository: TRL
- File: `trl/scripts/sft.py` (lines 91-116, model loading)
- File: `trl/trainer/utils.py` (lines 283-299, `get_quantization_config`; lines 302-306, `get_kbit_device_map`)
Signature
```python
# Pattern assembled in trl/scripts/sft.py:main()
def main(script_args, training_args, model_args, dataset_args):
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        dtype=model_args.dtype,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    if config.architectures and any(
        arch in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
        for arch in config.architectures
    ):
        model = AutoModelForImageTextToText.from_pretrained(
            model_args.model_name_or_path, **model_kwargs
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path, **model_kwargs
        )


# Helper functions
def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | None:
    ...


def get_kbit_device_map() -> dict[str, int] | None:
    ...
```
Import
```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_args.model_name_or_path | str | Yes | Hugging Face Hub model ID or local path to a model directory |
| model_args.dtype | str | No | Precision: "auto", "bfloat16", "float16", or "float32"; default: "float32" |
| model_args.model_revision | str | No | Git revision (branch, tag, or commit hash); default: "main" |
| model_args.trust_remote_code | bool | No | Allow execution of custom modeling code from the Hub; default: False |
| model_args.attn_implementation | str or None | No | Attention kernel: "flash_attention_2", "flash_attention_3", or None for the default |
| model_args.load_in_4bit | bool | No | Load base weights in 4-bit NF4 quantization; default: False |
| model_args.load_in_8bit | bool | No | Load base weights in 8-bit quantization; default: False |
| model_args.bnb_4bit_quant_type | str | No | Quantization type: "nf4" or "fp4"; default: "nf4" |
| model_args.use_bnb_nested_quant | bool | No | Double quantization of the quantization constants; default: False |
Outputs
| Name | Type | Description |
|---|---|---|
| model | PreTrainedModel | Loaded causal language model ready for fine-tuning or PEFT adapter wrapping |
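The returned model can be handed straight to SFTTrainer. A minimal sketch, assuming the trl-lib/Capybara dataset used elsewhere in this document:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
    model=model,                          # the PreTrainedModel loaded above
    args=SFTConfig(output_dir="./output"),
    train_dataset=dataset,
)
trainer.train()
```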
Usage Examples
Basic Usage (Full Precision)
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    dtype="bfloat16",
    attn_implementation="flash_attention_2",
)
```
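Note that `attn_implementation="flash_attention_2"` requires the flash-attn package and a compatible GPU; omit the argument to fall back to the default attention kernel.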
QLoRA Usage (4-bit Quantization)
```python
from transformers import AutoModelForCausalLM
from trl import ModelConfig, get_kbit_device_map, get_quantization_config

model_args = ModelConfig(
    model_name_or_path="meta-llama/Llama-3.1-8B",
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    use_bnb_nested_quant=True,
    dtype="bfloat16",
)

quantization_config = get_quantization_config(model_args)  # BitsAndBytesConfig
device_map = get_kbit_device_map()

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    quantization_config=quantization_config,
    device_map=device_map,
    dtype=model_args.dtype,
)
```
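Because the quantized base weights cannot receive gradient updates, this loading path is normally paired with LoRA adapters (the QLoRA workflow sketched under Usage) rather than trained directly.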
Using the TRL SFT Script (CLI)
```bash
python trl/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --load_in_4bit \
    --use_peft \
    --output_dir ./output
```
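The same script is also reachable through TRL's command-line entry point; assuming TRL is installed with its CLI, an equivalent invocation would be:

```bash
trl sft \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --load_in_4bit \
    --use_peft \
    --output_dir ./output
```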