

Implementation:Neuml Txtai HFTrainer Config

From Leeroopedia


Overview

This page documents the HFTrainer configuration methods, which parse training arguments, set up quantization, and configure LoRA adapters for parameter-efficient fine-tuning. They are called internally during the training workflow to turn user-provided configuration into the objects that the Hugging Face Trainer requires.

API

HFTrainer.parse

def parse(self, updates)

Merges custom training arguments with a set of defaults, with custom values taking precedence, and creates a TrainingArguments instance that controls all aspects of the training loop.

Parameters:

Name | Type | Description
updates | dict | Custom training arguments to merge with defaults. Supports any valid Hugging Face TrainingArguments parameter.

Returns: TrainingArguments -- A txtai-extended TrainingArguments instance with merged configuration.

Default arguments applied before merging:

Key | Default Value | Description
output_dir | "" | Empty string disables model saving
save_strategy | "no" | Disables intermediate checkpoint saving
report_to | "none" | Disables experiment tracking integrations
log_level | "warning" | Reduces console output to warnings only
use_cpu | auto-detected | True if no GPU or other accelerator is available

Example:

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Parse with custom overrides
args = trainer.parse({
    "num_train_epochs": 3,
    "per_device_train_batch_size": 16,
    "learning_rate": 2e-5,
    "output_dir": "models/my-model",
    "save_strategy": "epoch"
})

Note: The returned TrainingArguments is a txtai subclass that overrides should_save to return False whenever output_dir is empty (or otherwise falsy). This allows fully in-memory training without disk writes.
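The merge order and the should_save override can be sketched in plain Python. This is a minimal sketch, not txtai's code: SketchTrainingArguments is a hypothetical stand-in for the Hugging Face class, and the has_accelerator flag is an assumed replacement for txtai's device detection.

```python
from dataclasses import dataclass, field

# Defaults from the table above (use_cpu is added at parse time)
DEFAULTS = {
    "output_dir": "",
    "save_strategy": "no",
    "report_to": "none",
    "log_level": "warning",
}


@dataclass
class SketchTrainingArguments:
    """Hypothetical stand-in for txtai's TrainingArguments subclass."""

    settings: dict = field(default_factory=dict)
    # Stand-in for super().should_save, which in Hugging Face depends on
    # whether this is the main process
    base_should_save: bool = True

    @property
    def should_save(self):
        # Never write to disk when output_dir is empty or otherwise falsy
        return self.base_should_save if self.settings.get("output_dir") else False


def parse(updates, has_accelerator=False):
    """Merge user overrides on top of the defaults; user values win."""
    args = dict(DEFAULTS, use_cpu=not has_accelerator)
    args.update(updates)
    return SketchTrainingArguments(settings=args)


transient = parse({"num_train_epochs": 3})
print(transient.should_save)  # False, output_dir is ""

persisted = parse({"output_dir": "models/my-model", "save_strategy": "epoch"})
print(persisted.should_save)  # True
print(persisted.settings["report_to"])  # none
```

Copying DEFAULTS before updating keeps repeated calls from leaking one caller's overrides into the next.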

HFTrainer.quantization

def quantization(self, quantize)

Formats and returns a quantization configuration for loading the base model in reduced precision.

Parameters:

Name | Type | Description
quantize | bool, dict, BitsAndBytesConfig, or None | Quantization configuration. True applies defaults; a dict is passed to BitsAndBytesConfig; a BitsAndBytesConfig is used directly.

Returns: BitsAndBytesConfig or None -- The formatted quantization configuration, or None if quantization is not requested.

Default quantization settings when quantize=True:

Parameter | Value | Description
load_in_4bit | True | Loads model weights in 4-bit precision
bnb_4bit_use_double_quant | True | Quantizes the quantization constants for additional memory savings
bnb_4bit_quant_type | "nf4" | Uses the 4-bit NormalFloat data type, optimized for normally distributed weights
bnb_4bit_compute_dtype | "bfloat16" | Computation dtype for matrix multiplications

Example:

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Use default 4-bit quantization
config = trainer.quantization(True)

# Use custom quantization settings
config = trainer.quantization({
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "fp4",
    "bnb_4bit_compute_dtype": "float16"
})

# No quantization
config = trainer.quantization(None)  # Returns None

Important: Quantization requires a CUDA-compatible GPU. In the HFTrainer.model() method, the quantization configuration is cleared (set to None) if torch.cuda.is_available() returns False. This ensures the code runs on CPU without errors, albeit without quantization.
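The formatting rules and the CUDA gate described above can be sketched without bitsandbytes or torch. This is a minimal sketch under stated assumptions: SketchQuantConfig is a hypothetical stand-in for BitsAndBytesConfig, model_quantization is a hypothetical helper, and its cuda_available flag stands in for torch.cuda.is_available().

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchQuantConfig:
    """Hypothetical stand-in for BitsAndBytesConfig (only the fields used here)."""

    load_in_4bit: bool = False
    bnb_4bit_use_double_quant: bool = False
    bnb_4bit_quant_type: Optional[str] = None
    bnb_4bit_compute_dtype: Optional[str] = None


# Defaults documented for quantize=True
DEFAULT_QUANTIZATION = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
}


def quantization(quantize):
    """Sketch of HFTrainer.quantization."""
    if quantize is True:
        quantize = dict(DEFAULT_QUANTIZATION)

    # A dict becomes a config object; config objects and None pass through
    return SketchQuantConfig(**quantize) if isinstance(quantize, dict) else quantize


def model_quantization(quantize, cuda_available):
    """Hypothetical helper mirroring the model() gate: no CUDA, no quantization."""
    return quantization(quantize) if cuda_available else None


print(quantization(True).bnb_4bit_quant_type)  # nf4
print(model_quantization(True, cuda_available=False))  # None
```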

HFTrainer.lora

def lora(self, task, lora)

Formats and returns a LoRA configuration for parameter-efficient fine-tuning.

Parameters:

Name | Type | Description
task | str | Model task or category (e.g., "text-classification", "language-generation"). Used to determine the LoRA task type.
lora | bool, dict, LoraConfig, or None | LoRA configuration. True applies defaults; a dict is passed to LoraConfig; a LoraConfig is used directly.

Returns: LoraConfig or None -- The formatted LoRA configuration, or the input value if it is already a LoraConfig or falsy.

Default LoRA settings when lora=True:

Parameter | Value | Description
r | 16 | Rank of the low-rank decomposition matrices
lora_alpha | 8 | Scaling factor for the adapter update (applied as lora_alpha / r)
target_modules | "all-linear" | Applies LoRA adapters to all linear layers (PEFT excludes the output layer)
lora_dropout | 0.05 | Dropout rate applied to LoRA layers for regularization
bias | "none" | Bias parameters are not trained
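In PEFT's standard LoRA layers, the adapter update is multiplied by lora_alpha / r, or by lora_alpha / sqrt(r) when rank-stabilized LoRA (use_rslora) is enabled. The calculation below is a worked illustration, not txtai code. It shows that the defaults in the table above and the custom configuration in this section's example both give a scaling of 0.5. Doubling r while doubling lora_alpha therefore keeps the adapter's contribution at the same scale.

```python
import math


def lora_scaling(lora_alpha, r, use_rslora=False):
    """Scale applied to the LoRA update B @ A (PEFT convention)."""
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r


print(lora_scaling(8, 16))  # 0.5 (txtai defaults: r=16, lora_alpha=8)
print(lora_scaling(16, 32))  # 0.5 (custom example: r=32, lora_alpha=16)
```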

Task type auto-detection: If a dictionary configuration (including the defaults applied when lora=True) has no task_type key, the method sets it automatically via self.loratask(task).
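The formatting and task-type rules can be sketched as follows. This is a minimal sketch, not txtai's implementation, and it makes several assumptions. SketchLoraConfig is a hypothetical stand-in for peft.LoraConfig. The task resolver is passed in as a parameter instead of being a method. resolve_task_type is a toy resolver that covers only two cases. The sketch also copies the input dict so the caller's dictionary is left unmodified, which may differ from txtai.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class SketchLoraConfig:
    """Hypothetical stand-in for peft.LoraConfig (only the fields used here)."""

    r: Optional[int] = None
    lora_alpha: Optional[int] = None
    target_modules: Union[str, list, None] = None
    lora_dropout: Optional[float] = None
    bias: Optional[str] = None
    task_type: Optional[str] = None


# Defaults documented for lora=True
DEFAULT_LORA = {"r": 16, "lora_alpha": 8, "target_modules": "all-linear", "lora_dropout": 0.05, "bias": "none"}


def lora(task, config, loratask):
    """Sketch of HFTrainer.lora; loratask is passed in instead of being a method."""
    if config is True:
        config = dict(DEFAULT_LORA)

    if isinstance(config, dict):
        # Work on a copy, then fill in task_type only when the caller omitted it
        config = dict(config)
        config.setdefault("task_type", loratask(task))
        return SketchLoraConfig(**config)

    # Falsy values and ready-made config objects pass through unchanged
    return config


def resolve_task_type(task):
    # Toy resolver for this demo only; see the loratask mapping table for the full one
    return "CAUSAL_LM" if task == "language-generation" else "SEQ_CLS"


default = lora("text-classification", True, resolve_task_type)
print(default.r, default.task_type)  # 16 SEQ_CLS

custom = lora("language-generation", {"r": 32, "lora_alpha": 16}, resolve_task_type)
print(custom.task_type)  # CAUSAL_LM

explicit = lora("language-generation", {"r": 8, "task_type": "SEQ_2_SEQ_LM"}, resolve_task_type)
print(explicit.task_type)  # SEQ_2_SEQ_LM
```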

Example:

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Use default LoRA settings for classification
config = trainer.lora("text-classification", True)

# Custom LoRA configuration
config = trainer.lora("language-generation", {
    "r": 32,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "lora_dropout": 0.1,
    "bias": "none"
})

HFTrainer.peft

def peft(self, task, lora, model)

Wraps the input model as a PEFT model if LoRA configuration is set. This method is called after the base model is loaded and applies the LoRA adapters.

Parameters:

Name | Type | Description
task | str | Model task or category
lora | bool, dict, LoraConfig, or None | LoRA configuration
model | PreTrainedModel | The base transformer model to wrap

Returns: The wrapped PEFT model if LoRA is configured, otherwise returns the input model unchanged.

Behavior when LoRA is enabled:

  1. Formats the LoRA configuration via self.lora(task, lora).
  2. Calls prepare_model_for_kbit_training(model) to prepare quantized models for training (handles gradient checkpointing and layer norm casting).
  3. Calls get_peft_model(model, config) to inject LoRA adapter layers and freeze original parameters.
  4. Calls model.print_trainable_parameters() to display the number of trainable vs total parameters.
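The four steps above can be outlined as plain control flow. This is a sketch, not txtai's implementation. To keep it runnable without PEFT installed, the formatting, preparation, and wrapping functions are passed in as hypothetical parameters. In txtai they correspond to self.lora, prepare_model_for_kbit_training, and get_peft_model. FakePeftModel is a hypothetical stub.

```python
def peft(task, lora, model, format_lora, prepare, wrap):
    """Sketch of HFTrainer.peft with the PEFT steps injected as callables."""
    if not lora:
        # LoRA not requested: return the base model unchanged
        return model

    config = format_lora(task, lora)  # 1. format the LoRA configuration
    model = prepare(model)  # 2. prepare the (quantized) model for training
    model = wrap(model, config)  # 3. inject adapters, freeze base weights
    model.print_trainable_parameters()  # 4. report trainable vs total parameters
    return model


class FakePeftModel:
    """Hypothetical stub standing in for the model returned by get_peft_model."""

    def __init__(self, base, config):
        self.base = base
        self.config = config

    def print_trainable_parameters(self):
        print("trainable params: (stub)")


base = object()
wrapped = peft(
    "text-classification",
    True,
    base,
    format_lora=lambda task, lora: {"task": task, "lora": lora},
    prepare=lambda model: model,
    wrap=FakePeftModel,
)
print(type(wrapped).__name__)  # FakePeftModel
```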

HFTrainer.loratask

def loratask(self, task)

Looks up the corresponding LoRA TaskType for a given training task string.

Parameters:

Name | Type | Description
task | str | Model task or category

Returns: TaskType -- The PEFT library task type enum value.

Task mapping:

Input Task | LoRA TaskType
language-generation | TaskType.CAUSAL_LM
language-modeling | TaskType.FEATURE_EXTRACTION
question-answering | TaskType.QUESTION_ANS
sequence-sequence | TaskType.SEQ_2_SEQ_LM
text-classification | TaskType.SEQ_CLS
token-detection | TaskType.FEATURE_EXTRACTION

If the input task is not recognized, it defaults to "text-classification" (i.e., TaskType.SEQ_CLS).
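The mapping and its fallback amount to a dictionary lookup with a default. This is a sketch that uses a hypothetical TaskType enum standing in for peft.TaskType, so it runs without PEFT installed. It defines only the members the table uses.

```python
from enum import Enum


class TaskType(str, Enum):
    """Hypothetical stand-in for peft.TaskType (only the members used here)."""

    CAUSAL_LM = "CAUSAL_LM"
    FEATURE_EXTRACTION = "FEATURE_EXTRACTION"
    QUESTION_ANS = "QUESTION_ANS"
    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
    SEQ_CLS = "SEQ_CLS"


TASK_TYPES = {
    "language-generation": TaskType.CAUSAL_LM,
    "language-modeling": TaskType.FEATURE_EXTRACTION,
    "question-answering": TaskType.QUESTION_ANS,
    "sequence-sequence": TaskType.SEQ_2_SEQ_LM,
    "text-classification": TaskType.SEQ_CLS,
    "token-detection": TaskType.FEATURE_EXTRACTION,
}


def loratask(task):
    # Unrecognized tasks fall back to the text-classification mapping
    return TASK_TYPES.get(task, TASK_TYPES["text-classification"])


print(loratask("language-generation").name)  # CAUSAL_LM
print(loratask("summarization").name)  # SEQ_CLS (fallback)
```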

Source

  • src/python/txtai/pipeline/train/hftrainer.py (lines 146-354)

Import

from txtai.pipeline import HFTrainer

Dependencies

Quantization and LoRA require optional dependencies:

# Required for quantization and LoRA
pip install peft bitsandbytes
# Or install via txtai extras (quotes stop shells such as zsh from expanding the brackets)
pip install "txtai[pipeline]"

If peft is not installed and quantize or lora is passed, an ImportError is raised with the message: 'PEFT is not available - install "pipeline" extra to enable'.
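The sketch below shows a common optional-dependency pattern that matches this behavior. txtai's exact check is not reproduced here. The names PEFT_AVAILABLE and require_peft, and the available parameter, are hypothetical. The import is attempted once at module load, and the error is raised only when quantization or LoRA is actually requested.

```python
# Optional-dependency guard: note whether peft imports, fail only when it is needed
try:
    import peft  # noqa: F401

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False


def require_peft(quantize, lora, available=PEFT_AVAILABLE):
    """Raise the documented error when quantize or lora is set but peft is missing."""
    if (quantize or lora) and not available:
        raise ImportError('PEFT is not available - install "pipeline" extra to enable')


require_peft(None, None, available=False)  # no-op: nothing requested

try:
    require_peft(True, None, available=False)
except ImportError as error:
    print(error)
```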

See Also

Uses Heuristic
