
Implementation:Huggingface Trl AutoModelForSequenceClassification From Pretrained

From Leeroopedia


  • Implementation Name: AutoModelForSequenceClassification From Pretrained
  • Technology: Huggingface TRL, Transformers
  • Type: API Doc
  • Workflow: Reward Model Training
  • Principle: Principle:Huggingface_Trl_Reward_Sequence_Classifier_Loading

Overview

Description

The reward model is loaded by the create_model_from_path utility function, which uses AutoModelForSequenceClassification as the architecture class. This produces a pretrained language model with a linear classification head configured for single-label regression (num_labels=1), so the model outputs one scalar reward per sequence. When the model argument passed to RewardTrainer.__init__ is a string path, this function is called automatically.

Usage

Model loading is triggered internally when a string model path is passed to RewardTrainer. It can also be used directly via create_model_from_path for custom setups.

Code Reference

Source Location

  • create_model_from_path: trl/trainer/utils.py lines 1133-1169
  • RewardTrainer model loading: trl/trainer/reward_trainer.py lines 308-313

Signature

def create_model_from_path(
    model_id: str,
    architecture: _BaseAutoModelClass | None = None,
    **kwargs
) -> PreTrainedModel:
    """
    Create a model from a given path using the specified initialization arguments.

    Args:
        model_id: Path to the model (local or Hub identifier).
        architecture: Model architecture class (e.g., AutoModelForSequenceClassification).
        **kwargs: Keyword arguments passed to from_pretrained.

    Returns:
        The instantiated PreTrainedModel.
    """
# In RewardTrainer.__init__ (reward_trainer.py L308-313):
if isinstance(model, str):
    model_init_kwargs = args.model_init_kwargs or {}
    if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
        model_init_kwargs["device_map"] = None
    model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs)

Import

from transformers import AutoModelForSequenceClassification
from trl.trainer.utils import create_model_from_path

Key Configuration

The critical setting is num_labels=1, applied when the model is loaded through the AutoModelForSequenceClassification architecture for reward scoring. It configures the classification head for scalar reward output rather than multi-class classification.
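As a minimal sketch of what num_labels=1 implies, the snippet below instantiates a tiny, randomly initialized GPT-2 sequence classifier (a stand-in assumption for illustration, not the actual TRL loading path) and checks that the classification head emits exactly one logit per sequence:

```python
import torch
from transformers import GPT2Config, GPT2ForSequenceClassification

# Tiny random config as a stand-in for a real checkpoint (assumption: any
# sequence-classification architecture behaves the same way here).
config = GPT2Config(
    n_layer=2, n_head=2, n_embd=32, vocab_size=100,
    num_labels=1,   # single-label regression head -> scalar reward
    pad_token_id=0,
)
model = GPT2ForSequenceClassification(config)

input_ids = torch.tensor([[1, 2, 3, 4]])  # one short made-up sequence
out = model(input_ids)

# With num_labels=1, logits have shape (batch_size, 1): one scalar per sequence.
print(tuple(out.logits.shape))
```

A multi-class setup would instead produce shape (batch_size, num_labels); the num_labels=1 head is what turns the classifier into a reward scorer.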

I/O Contract

Inputs

  • model_id (str, required): Model path (local directory or Hugging Face Hub model ID)
  • architecture (_BaseAutoModelClass or None, default None): Architecture class; set to AutoModelForSequenceClassification for reward models
  • dtype (str or torch.dtype, default "float32"): Data type for model weights; accepts "auto", "bfloat16", "float16", "float32"
  • device_map (str or dict, default "auto"): Device placement strategy; set to None for distributed training

Outputs

  • model (PreTrainedModel): Loaded model with a classification head (score layer) producing scalar rewards

Usage Examples

Direct Loading

from transformers import AutoModelForSequenceClassification
from trl.trainer.utils import create_model_from_path

# Load a pretrained model as a reward model
reward_model = create_model_from_path(
    "Qwen/Qwen2.5-0.5B-Instruct",
    architecture=AutoModelForSequenceClassification,
    num_labels=1,
    dtype="bfloat16",
)

Via RewardTrainer (Automatic Loading)

from trl import RewardTrainer, RewardConfig

# When a string is passed, RewardTrainer loads the model automatically
trainer = RewardTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=RewardConfig(output_dir="reward-output"),
    train_dataset=dataset,
)
# The model is now an AutoModelForSequenceClassification with num_labels=1
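Once loaded, the reward model is used by taking the single logit as the sequence's score. The sketch below shows this scoring step under stated assumptions: it substitutes a tiny randomly initialized classifier for a real checkpoint, and the token ids are made up for illustration (a real setup would come from a chat-templated tokenizer):

```python
import torch
from transformers import GPT2Config, GPT2ForSequenceClassification

# Tiny random stand-in for a trained reward model (illustration only).
config = GPT2Config(
    n_layer=2, n_head=2, n_embd=32, vocab_size=100,
    num_labels=1, pad_token_id=0,
)
reward_model = GPT2ForSequenceClassification(config).eval()

# Hypothetical token ids for a chosen/rejected completion pair.
chosen = torch.tensor([[5, 17, 42, 8]])
rejected = torch.tensor([[5, 17, 42, 9]])

with torch.no_grad():
    # logits has shape (1, 1); index [0, 0] extracts the scalar reward.
    r_chosen = reward_model(chosen).logits[0, 0].item()
    r_rejected = reward_model(rejected).logits[0, 0].item()

print(r_chosen, r_rejected)  # two Python floats, one reward per sequence
```

During reward model training, pairs of scores like these are compared with a preference loss so that chosen completions receive higher rewards than rejected ones.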
