Implementation: HuggingFace Transformers Trainer.train()
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training, Deep Learning |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete tool for executing the full model training loop including forward pass, backpropagation, optimization, logging, and checkpointing, provided by the HuggingFace Transformers library.
Description
Trainer.train() is the main training entry point that orchestrates the complete training process. It handles checkpoint resumption, model re-initialization (for hyperparameter search), and delegates to an inner training loop that performs the actual gradient computation and parameter updates.
The method supports:
- Checkpoint resumption -- Restoring model, optimizer, scheduler, and RNG states from a previous checkpoint.
- Automatic batch size finding -- Exponentially reducing batch size on OOM errors.
- Hyperparameter search integration -- Accepting trial objects from Optuna or Ray Tune.
- Hub synchronization -- Automatically pushing checkpoints to the HuggingFace Hub during training.
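The automatic batch size finding above can be sketched in plain Python: on an out-of-memory error, halve the batch size and retry. This is an illustrative stand-in, not the library's actual implementation (Transformers delegates this to `accelerate`'s `find_executable_batch_size` helper, enabled via `TrainingArguments(auto_find_batch_size=True)`).

```python
# Schematic sketch of exponential batch size reduction on OOM.
# MemoryError stands in for torch.cuda.OutOfMemoryError.

def find_executable_batch_size(train_fn, starting_batch_size=128):
    """Retry train_fn with exponentially smaller batch sizes on MemoryError."""
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return batch_size, train_fn(batch_size)
        except MemoryError:
            batch_size //= 2  # exponential reduction, as described above
    raise RuntimeError("No executable batch size found")

# Example: pretend any batch size above 32 exhausts memory.
def fake_train(batch_size):
    if batch_size > 32:
        raise MemoryError
    return f"trained with batch size {batch_size}"
```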
The inner training loop (_inner_training_loop) manages the per-step logic: iterating over batches, computing loss, accumulating gradients, clipping gradients, stepping the optimizer and scheduler, logging metrics, running evaluation, and saving checkpoints.
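The per-step logic can be illustrated with a minimal sketch. Plain-Python scalars stand in for torch tensors, and the "model" is a single parameter fit by squared error; the real loop additionally handles mixed precision, distributed synchronization, evaluation, and checkpointing.

```python
# Schematic of _inner_training_loop's core step logic: accumulate
# gradients over N micro-batches, clip by norm, then take one
# optimizer step and advance the global step counter.

def inner_training_loop(batches, accumulation_steps=2, max_grad_norm=1.0, lr=0.1):
    param = 0.0        # single scalar "parameter"
    grad = 0.0         # accumulated gradient
    global_step = 0
    for i, target in enumerate(batches, start=1):
        # forward + backward: gradient of 0.5 * (param - target)^2,
        # scaled so the accumulated gradient averages the micro-batches
        grad += (param - target) / accumulation_steps
        if i % accumulation_steps == 0:
            # gradient clipping (scalar analogue of clip_grad_norm_)
            if abs(grad) > max_grad_norm:
                grad = max_grad_norm if grad > 0 else -max_grad_norm
            param -= lr * grad   # optimizer step
            grad = 0.0           # zero gradients
            global_step += 1     # scheduler, logging, eval, saving hook in here
    return param, global_step
```

Note that `global_step` counts optimizer steps, not micro-batches: four batches with `accumulation_steps=2` yield two steps.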
Usage
Call trainer.train() after all preceding setup steps are complete (data loading, tokenization, model loading, configuration, Trainer initialization). Optionally pass a checkpoint path to resume from.
Code Reference
Source Location
- Repository: transformers
- File: src/transformers/trainer.py (lines 1302-1408 for train(); lines 1410+ for _inner_training_loop)
Signature
def train(
    self,
    resume_from_checkpoint: str | bool | None = None,
    trial: "optuna.Trial | dict[str, Any] | None" = None,
    ignore_keys_for_eval: list[str] | None = None,
) -> TrainOutput:
Import
from transformers import Trainer
# train() is an instance method called on a Trainer object
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| resume_from_checkpoint | str or bool | No | Path to a checkpoint directory to resume from. If True, resumes from the last checkpoint in output_dir. If None or False, trains from scratch |
| trial | optuna.Trial or dict | No | Trial object for hyperparameter search (Optuna) or hyperparameter dictionary (Ray Tune) |
| ignore_keys_for_eval | list[str] | No | Keys in the model output dictionary to ignore when gathering predictions for evaluation during training |
Outputs
| Name | Type | Description |
|---|---|---|
| train_output | TrainOutput | Object containing global_step (total optimization steps completed), training_loss (average training loss), and metrics (dict with training speed and memory metrics) |
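To clarify the return value's fields, here is a schematic equivalent of TrainOutput (the real class lives in `transformers.trainer_utils`; the exact metric keys are illustrative):

```python
# Schematic stand-in for transformers.trainer_utils.TrainOutput.
from typing import Any, NamedTuple

class TrainOutput(NamedTuple):
    global_step: int         # total optimization steps completed
    training_loss: float     # average loss over all training steps
    metrics: dict[str, Any]  # speed/memory metrics, e.g. runtime

# training_loss is the sum of per-step losses divided by the step count:
step_losses = [2.0, 1.5, 1.0, 0.5]
out = TrainOutput(
    global_step=len(step_losses),
    training_loss=sum(step_losses) / len(step_losses),
    metrics={"train_runtime": 12.3},  # hypothetical value
)
```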
Usage Examples
Basic Usage
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./results", num_train_epochs=3),
    train_dataset=train_dataset,
)
result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
print(f"Global steps: {result.global_step}")
Resuming from Checkpoint
# Resume from a specific checkpoint
result = trainer.train(resume_from_checkpoint="./results/checkpoint-500")
# Resume from the last checkpoint in output_dir
result = trainer.train(resume_from_checkpoint=True)
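When `resume_from_checkpoint=True` is passed, Trainer resolves it to the `checkpoint-<step>` subdirectory of `output_dir` with the highest step number (the real helper is `transformers.trainer_utils.get_last_checkpoint`). A simplified sketch of that resolution:

```python
# Hedged sketch of last-checkpoint resolution: pick the checkpoint-<N>
# subdirectory with the largest N, or None if there are no checkpoints.
import os
import re

def get_last_checkpoint(output_dir):
    pattern = re.compile(r"^checkpoint-(\d+)$")
    checkpoints = [
        (int(m.group(1)), name)
        for name in os.listdir(output_dir)
        if (m := pattern.match(name))
        and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not checkpoints:
        return None
    # tuples sort by step number first, so max() picks the latest
    return os.path.join(output_dir, max(checkpoints)[1])
```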
Full Training Pipeline
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# Setup
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# Causal LM training needs labels: this collator pads each batch and
# copies input_ids into labels (masking pad positions with -100).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=100,
    save_strategy="epoch",
    bf16=True,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=data_collator,
)
result = trainer.train()