
Implementation:Huggingface Transformers Trainer Train

From Leeroopedia
Knowledge Sources
Domains NLP, Training, Deep Learning
Last Updated 2026-02-13 00:00 GMT

Overview

A concrete tool, provided by the HuggingFace Transformers library, for executing the full model training loop: forward pass, backpropagation, optimization, logging, and checkpointing.

Description

Trainer.train() is the main training entry point that orchestrates the complete training process. It handles checkpoint resumption, model re-initialization (for hyperparameter search), and delegates to an inner training loop that performs the actual gradient computation and parameter updates.

The method supports:

  • Checkpoint resumption -- Restoring model, optimizer, scheduler, and RNG states from a previous checkpoint.
  • Automatic batch size finding -- Exponentially reducing batch size on OOM errors.
  • Hyperparameter search integration -- Accepting trial objects from Optuna or Ray Tune.
  • Hub synchronization -- Automatically pushing checkpoints to the HuggingFace Hub during training.
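The automatic batch-size finder can be pictured as a retry loop: run the training function, and whenever it raises an out-of-memory error, halve the batch size and try again (Trainer delegates this to the `find_executable_batch_size` helper from the accelerate library). A minimal pure-Python sketch, with `fake_train` as a stand-in for the real training step:

```python
# Sketch of the auto-batch-size heuristic: retry the training function with
# an exponentially reduced batch size whenever it runs out of memory.
# (The real implementation catches CUDA OOM errors, not MemoryError.)
def find_executable_batch_size(train_fn, starting_batch_size=256):
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return batch_size, train_fn(batch_size)
        except MemoryError:
            batch_size //= 2  # exponential reduction
    raise RuntimeError("no executable batch size found; even batch_size=1 OOMs")

# Toy training step that pretends any batch size above 64 exhausts memory.
def fake_train(batch_size):
    if batch_size > 64:
        raise MemoryError
    return f"trained with batch_size={batch_size}"

print(find_executable_batch_size(fake_train))  # → (64, 'trained with batch_size=64')
```

In the real Trainer this behavior is opt-in via `TrainingArguments(auto_find_batch_size=True)`.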

The inner training loop (_inner_training_loop) manages the per-step logic: iterating over batches, computing loss, accumulating gradients, clipping gradients, stepping the optimizer and scheduler, logging metrics, running evaluation, and saving checkpoints.
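The per-step logic can be sketched on a toy 1-D linear model. This is an illustrative simplification, not the real `_inner_training_loop` internals: gradients are computed by hand, clipping is scalar, and the scheduler is a plain linear decay.

```python
# Schematic inner training loop: accumulate gradients over micro-batches,
# clip, step the optimizer and LR scheduler, and log the loss.
def inner_training_loop(data, epochs=3, lr=0.1, accum_steps=2, max_grad_norm=1.0):
    w = 0.0                       # single model parameter (prediction = w * x)
    grad, losses, global_step = 0.0, [], 0
    total_updates = epochs * (len(data) // accum_steps)
    for epoch in range(epochs):
        for i, (x, y) in enumerate(data, start=1):
            pred = w * x
            loss = (pred - y) ** 2                        # forward pass
            grad += 2 * (pred - y) * x / accum_steps      # accumulate gradients
            if i % accum_steps == 0:
                if abs(grad) > max_grad_norm:             # clip gradients
                    grad = max_grad_norm if grad > 0 else -max_grad_norm
                step_lr = lr * (1 - global_step / total_updates)  # linear decay
                w -= step_lr * grad                       # optimizer step
                grad = 0.0
                global_step += 1
                losses.append(loss)                       # log metrics
    return w, global_step, losses

data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]
w, steps, losses = inner_training_loop(data)
```

With 4 examples, 3 epochs, and gradient accumulation of 2, the loop performs 6 optimizer updates; evaluation and checkpointing hooks (omitted here) fire at the same point, right after the scheduler step.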

Usage

Call trainer.train() after all preceding setup steps are complete (data loading, tokenization, model loading, configuration, Trainer initialization). Optionally pass a checkpoint path to resume from.

Code Reference

Source Location

  • Repository: transformers
  • File: src/transformers/trainer.py (lines 1302-1408 for train(); lines 1410+ for _inner_training_loop)

Signature

def train(
    self,
    resume_from_checkpoint: str | bool | None = None,
    trial: "optuna.Trial | dict[str, Any] | None" = None,
    ignore_keys_for_eval: list[str] | None = None,
) -> TrainOutput:

Import

from transformers import Trainer
# train() is an instance method called on a Trainer object

I/O Contract

Inputs

  • resume_from_checkpoint (str or bool, optional) -- Path to a checkpoint directory to resume from. If True, resumes from the last checkpoint in output_dir; if None or False, trains from scratch.
  • trial (optuna.Trial or dict, optional) -- Trial object for hyperparameter search (Optuna) or a hyperparameter dictionary (Ray Tune).
  • ignore_keys_for_eval (list[str], optional) -- Keys in the model output dictionary to ignore when gathering predictions for evaluation during training.

Outputs

  • train_output (TrainOutput) -- Contains global_step (total optimization steps completed), training_loss (average training loss), and metrics (a dict of training speed and memory metrics).

Usage Examples

Basic Usage

from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./results", num_train_epochs=3),
    train_dataset=train_dataset,
)

result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
print(f"Global steps: {result.global_step}")

Resuming from Checkpoint

# Resume from a specific checkpoint
result = trainer.train(resume_from_checkpoint="./results/checkpoint-500")

# Resume from the last checkpoint in output_dir
result = trainer.train(resume_from_checkpoint=True)
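Resumption reproduces the interrupted run because a checkpoint captures everything the loop depends on: model parameters, optimizer and scheduler state, the step counter, and RNG states (the real checkpoint directory holds these across several files, including the saved RNG state). A pure-Python sketch, with illustrative state layout and file names, showing that a resumed run matches an uninterrupted one:

```python
import os
import pickle
import random
import tempfile

def init_state(seed=0):
    random.seed(seed)
    return {"w": 0.0, "step": 0, "rng": random.getstate()}

def train(state, n_steps):
    random.setstate(state["rng"])   # restore RNG (data order, dropout, ...)
    for _ in range(n_steps):
        state["w"] -= 0.1 * random.uniform(-1, 1)   # stand-in for an update
        state["step"] += 1
    state["rng"] = random.getstate()  # capture RNG for the next resume
    return state

def save_checkpoint(state, path):
    with open(path, "wb") as f:
        pickle.dump(state, f)

def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)

# Uninterrupted run: 10 steps.
full = train(init_state(), 10)

# Interrupted run: 5 steps, checkpoint, reload, 5 more steps.
ckpt = os.path.join(tempfile.mkdtemp(), "checkpoint-5.pkl")
save_checkpoint(train(init_state(), 5), ckpt)
resumed = train(load_checkpoint(ckpt), 5)

assert resumed == full  # resumed run is identical to the uninterrupted one
```

If any of the saved pieces were dropped (most commonly the RNG state), the resumed run would diverge from the original, which is why Trainer restores all of them.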

Full Training Pipeline

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# Setup
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# Causal LM training needs labels: this collator pads each batch and
# copies input_ids to labels (mlm=False). Without it, the model returns
# no loss and train() fails.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=100,
    save_strategy="epoch",
    bf16=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=data_collator,
)
result = trainer.train()

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
