

Implementation:Neuml Txtai HFTrainer Config

From Leeroopedia
Revision as of 16:04, 16 February 2026 by Admin (auto-imported from implementations/Neuml_Txtai_HFTrainer_Config.md)


Overview

This page documents the HFTrainer methods that parse training arguments, set up quantization, and configure LoRA adapters for parameter-efficient fine-tuning. They are called internally during the training workflow to turn user-provided configuration into the objects that the Hugging Face Trainer requires.

API

HFTrainer.parse

def parse(self, updates)

Merges custom training arguments over a set of defaults and returns a TrainingArguments instance, which configures the training loop (epochs, batch size, learning rate, checkpointing, logging, and so on).

Parameters:

| Name | Type | Description |
|---|---|---|
| updates | dict | Custom training arguments to merge with defaults. Supports any valid Hugging Face TrainingArguments parameter. |

Returns: TrainingArguments -- A txtai-extended TrainingArguments instance with merged configuration.

Default arguments applied before merging:

| Key | Default value | Description |
|---|---|---|
| output_dir | "" | Empty string disables model saving |
| save_strategy | "no" | Disables intermediate checkpoint saving |
| report_to | "none" | Disables experiment tracking integrations |
| log_level | "warning" | Reduces console output to warnings only |
| use_cpu | auto-detected | True if no GPU/accelerator is available |
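Because the defaults are written first and the user's values are applied on top, any default can be overridden key by key, and keys without a default are simply added. A minimal pure-Python sketch of that merge order, assuming a hypothetical `merge_training_args` helper and an explicit `has_accelerator` flag in place of txtai's own device detection (it returns a plain dict rather than a TrainingArguments object):

```python
from typing import Any, Dict


def merge_training_args(updates: Dict[str, Any], has_accelerator: bool) -> Dict[str, Any]:
    """Hypothetical helper: overlay user arguments on HFTrainer-style defaults."""
    # Defaults from the table above; use_cpu is derived from the assumed
    # has_accelerator flag instead of real device detection
    args = {
        "output_dir": "",
        "save_strategy": "no",
        "report_to": "none",
        "log_level": "warning",
        "use_cpu": not has_accelerator,
    }

    # User-supplied keys overwrite defaults; unknown keys are added as-is
    args.update(updates)
    return args


merged = merge_training_args({"num_train_epochs": 3, "save_strategy": "epoch"}, has_accelerator=False)
print(merged["save_strategy"], merged["use_cpu"], merged["report_to"])  # epoch True none
```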

Example:

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Parse with custom overrides
args = trainer.parse({
    "num_train_epochs": 3,
    "per_device_train_batch_size": 16,
    "learning_rate": 2e-5,
    "output_dir": "models/my-model",
    "save_strategy": "epoch"
})

Note: The returned TrainingArguments is a txtai subclass that overrides should_save to return False when output_dir is empty or falsy. This enables fully in-memory training without disk writes.
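The override can be pictured with two small stand-in classes. This sketch does not import transformers: `BaseArgs` is a hypothetical substitute for Hugging Face's TrainingArguments whose `should_save` is assumed to be true, and `InMemoryArgs` mimics the behavior described in the note, not txtai's actual source.

```python
class BaseArgs:
    """Hypothetical stand-in for Hugging Face TrainingArguments."""

    def __init__(self, output_dir: str = ""):
        self.output_dir = output_dir

    @property
    def should_save(self) -> bool:
        # Assume the base class would normally allow saving
        return True


class InMemoryArgs(BaseArgs):
    """Mimics the described override: never save without an output_dir."""

    @property
    def should_save(self) -> bool:
        # An empty or otherwise falsy output_dir short-circuits to False
        return bool(self.output_dir) and super().should_save


print(InMemoryArgs("").should_save)                 # False
print(InMemoryArgs("models/my-model").should_save)  # True
```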

HFTrainer.quantization

def quantization(self, quantize)

Formats and returns a quantization configuration for loading the base model in reduced precision.

Parameters:

| Name | Type | Description |
|---|---|---|
| quantize | bool, dict, BitsAndBytesConfig, or None | Quantization configuration. True applies defaults; a dict is passed to BitsAndBytesConfig; a BitsAndBytesConfig is used directly. |

Returns: BitsAndBytesConfig or None -- The formatted quantization configuration, or None if quantization is not requested.

Default quantization settings when quantize=True:

| Parameter | Value | Description |
|---|---|---|
| load_in_4bit | True | Loads model weights in 4-bit precision |
| bnb_4bit_use_double_quant | True | Quantizes the quantization constants for additional memory savings |
| bnb_4bit_quant_type | "nf4" | Uses Normal Float 4-bit data type optimized for normally-distributed weights |
| bnb_4bit_compute_dtype | "bfloat16" | Computation dtype for matrix multiplications |
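The three accepted input forms reduce to a short normalization routine. The following is a dependency-free sketch, not txtai's code: `normalize_quantization` is hypothetical, and `FakeBnbConfig` stands in for transformers' BitsAndBytesConfig so the example runs without transformers or a GPU.

```python
from typing import Any, Optional

# Defaults applied when quantize=True (from the table above)
DEFAULT_QUANTIZATION = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
}


class FakeBnbConfig:
    """Hypothetical stand-in for transformers.BitsAndBytesConfig."""

    def __init__(self, **kwargs: Any):
        self.kwargs = kwargs


def normalize_quantization(quantize: Any, config_class: type = FakeBnbConfig) -> Optional[Any]:
    """Hypothetical helper: turn True / dict / config input into a config object."""
    # Falsy input (None, False, {}) means no quantization
    if not quantize:
        return None

    # A ready-made config object is used directly
    if isinstance(quantize, config_class):
        return quantize

    # True (or any other truthy non-dict value) selects a copy of the defaults
    if not isinstance(quantize, dict):
        quantize = dict(DEFAULT_QUANTIZATION)

    # Dictionary keys become keyword arguments of the config class
    return config_class(**quantize)


default_config = normalize_quantization(True)
print(default_config.kwargs["bnb_4bit_quant_type"])  # nf4
```

Copying the defaults before use keeps repeated calls from sharing, and possibly mutating, one module-level dictionary.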

Example:

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Use default 4-bit quantization
config = trainer.quantization(True)

# Use custom quantization settings
config = trainer.quantization({
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "fp4",
    "bnb_4bit_compute_dtype": "float16"
})

# No quantization
config = trainer.quantization(None)  # Returns None

Important: Quantization requires a CUDA-compatible GPU. In the HFTrainer.model() method, the quantization configuration is cleared (set to None) if torch.cuda.is_available() returns False. This ensures the code runs on CPU without errors, albeit without quantization.

HFTrainer.lora

def lora(self, task, lora)

Formats and returns a LoRA configuration for parameter-efficient fine-tuning.

Parameters:

| Name | Type | Description |
|---|---|---|
| task | str | Model task or category (e.g., "text-classification", "language-generation"). Used to determine the LoRA task type. |
| lora | bool, dict, LoraConfig, or None | LoRA configuration. True applies defaults; a dict is passed to LoraConfig; a LoraConfig is used directly. |

Returns: LoraConfig or None -- The formatted LoRA configuration, or the input value if it is already a LoraConfig or falsy.

Default LoRA settings when lora=True:

| Parameter | Value | Description |
|---|---|---|
| r | 16 | Rank of the low-rank decomposition matrices |
| lora_alpha | 8 | Scaling factor controlling adapter contribution magnitude |
| target_modules | "all-linear" | Applies LoRA adapters to all linear layers |
| lora_dropout | 0.05 | Dropout rate applied to LoRA layers for regularization |
| bias | "none" | Bias parameters are not trained |

Task type auto-detection: If a dictionary configuration (including the default settings applied for lora=True) has no task_type key, the method fills it in by looking up the task with self.loratask(task).
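Putting the defaults and the task-type rule together gives the following dependency-free sketch. `normalize_lora` is hypothetical, `FakeLoraConfig` stands in for peft's LoraConfig, and the `task_lookup` parameter stands in for self.loratask. The sketch copies the input dictionary before filling in task_type so the caller's dictionary is left untouched.

```python
from typing import Any, Callable, Optional

# Defaults applied when lora=True (from the table above)
DEFAULT_LORA = {
    "r": 16,
    "lora_alpha": 8,
    "target_modules": "all-linear",
    "lora_dropout": 0.05,
    "bias": "none",
}


class FakeLoraConfig:
    """Hypothetical stand-in for peft.LoraConfig."""

    def __init__(self, **kwargs: Any):
        self.kwargs = kwargs


def normalize_lora(task: str, lora: Any, task_lookup: Callable[[str], str]) -> Optional[Any]:
    """Hypothetical helper: turn True / dict / config input into a LoRA config."""
    # Falsy values and ready-made config objects pass through unchanged
    if not lora or isinstance(lora, FakeLoraConfig):
        return lora

    # A dict is used as given; True (or another truthy value) selects the defaults
    settings = dict(lora) if isinstance(lora, dict) else dict(DEFAULT_LORA)

    # Fill in task_type only when the caller did not provide one
    if "task_type" not in settings:
        settings["task_type"] = task_lookup(task)

    return FakeLoraConfig(**settings)


def example_lookup(task: str) -> str:
    # Hypothetical lookup used only for this demo
    return "CAUSAL_LM" if task == "language-generation" else "SEQ_CLS"


config = normalize_lora("text-classification", True, example_lookup)
print(config.kwargs["r"], config.kwargs["task_type"])  # 16 SEQ_CLS
```

Note that a custom dictionary is not merged with the defaults: only the keys it contains, plus task_type, reach the config.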

Example:

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Use default LoRA settings for classification
config = trainer.lora("text-classification", True)

# Custom LoRA configuration
config = trainer.lora("language-generation", {
    "r": 32,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "lora_dropout": 0.1,
    "bias": "none"
})

HFTrainer.peft

def peft(self, task, lora, model)

Wraps the input model as a PEFT model if LoRA configuration is set. This method is called after the base model is loaded and applies the LoRA adapters.

Parameters:

| Name | Type | Description |
|---|---|---|
| task | str | Model task or category |
| lora | bool, dict, LoraConfig, or None | LoRA configuration |
| model | PreTrainedModel | The base transformer model to wrap |

Returns: The wrapped PEFT model if LoRA is configured, otherwise returns the input model unchanged.

Behavior when LoRA is enabled:

  1. Formats the LoRA configuration via self.lora(task, lora).
  2. Calls prepare_model_for_kbit_training(model) to prepare quantized models for training (handles gradient checkpointing and layer norm casting).
  3. Calls get_peft_model(model, config) to inject LoRA adapter layers and freeze original parameters.
  4. Calls model.print_trainable_parameters() to display the number of trainable vs total parameters.
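The four steps form a short pipeline. The sketch below keeps that order but injects the PEFT operations as parameters so it runs without peft or a real model. `apply_lora` and its parameters are hypothetical; in txtai they would correspond to self.lora, prepare_model_for_kbit_training, and get_peft_model.

```python
from typing import Any, Callable, List


def apply_lora(
    task: str,
    lora: Any,
    model: Any,
    format_config: Callable[[str, Any], Any],
    prepare: Callable[[Any], Any],
    wrap: Callable[[Any, Any], Any],
) -> Any:
    """Hypothetical sketch of the wrap-if-configured flow."""
    # No LoRA requested: return the base model untouched
    if not lora:
        return model

    # 1. Build a LoRA config from True / dict / config input
    config = format_config(task, lora)

    # 2. Prepare the (possibly quantized) model for training
    model = prepare(model)

    # 3. Inject adapter layers and freeze the original weights
    model = wrap(model, config)

    # 4. Report trainable vs. total parameters
    model.print_trainable_parameters()
    return model


class StubPeftModel:
    """Hypothetical stand-in for a PEFT-wrapped model."""

    def print_trainable_parameters(self) -> None:
        print("trainable params: (stub)")


steps: List[str] = []


def stub_prepare(model: Any) -> Any:
    steps.append("prepare")
    return model


def stub_wrap(model: Any, config: Any) -> Any:
    steps.append("wrap")
    return StubPeftModel()


base_model = object()
unchanged = apply_lora("text-classification", None, base_model, lambda t, l: l, stub_prepare, stub_wrap)
wrapped = apply_lora("text-classification", True, base_model, lambda t, l: {"r": 16}, stub_prepare, stub_wrap)
print(steps)  # ['prepare', 'wrap']
```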

HFTrainer.loratask

def loratask(self, task)

Looks up the corresponding LoRA TaskType for a given training task string.

Parameters:

| Name | Type | Description |
|---|---|---|
| task | str | Model task or category |

Returns: TaskType -- The PEFT library task type enum value.

Task mapping:

| Input task | LoRA TaskType |
|---|---|
| language-generation | TaskType.CAUSAL_LM |
| language-modeling | TaskType.FEATURE_EXTRACTION |
| question-answering | TaskType.QUESTION_ANS |
| sequence-sequence | TaskType.SEQ_2_SEQ_LM |
| text-classification | TaskType.SEQ_CLS |
| token-detection | TaskType.FEATURE_EXTRACTION |

If the input task is not recognized, it defaults to "text-classification" (i.e., TaskType.SEQ_CLS).
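The lookup and its fallback amount to a dictionary access with a default key. This sketch uses the TaskType member names as plain strings so it runs without peft; with peft installed, the values would be the corresponding `peft.TaskType` members. `LORA_TASKS` and `lora_task_name` are hypothetical names.

```python
# Task-to-TaskType mapping from the table above, as member-name strings
LORA_TASKS = {
    "language-generation": "CAUSAL_LM",
    "language-modeling": "FEATURE_EXTRACTION",
    "question-answering": "QUESTION_ANS",
    "sequence-sequence": "SEQ_2_SEQ_LM",
    "text-classification": "SEQ_CLS",
    "token-detection": "FEATURE_EXTRACTION",
}


def lora_task_name(task: str) -> str:
    """Hypothetical helper mirroring the described loratask lookup."""
    # Unrecognized tasks fall back to text-classification
    key = task if task in LORA_TASKS else "text-classification"
    return LORA_TASKS[key]


print(lora_task_name("language-generation"))  # CAUSAL_LM
print(lora_task_name("summarization"))        # SEQ_CLS
```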

Source

  • src/python/txtai/pipeline/train/hftrainer.py (lines 146-354)

Import

from txtai.pipeline import HFTrainer

Dependencies

Quantization and LoRA require optional dependencies:

# Required for quantization and LoRA
pip install peft bitsandbytes
# Or install via txtai extras (quoted so shells such as zsh do not expand the brackets)
pip install "txtai[pipeline]"

If peft is not installed and quantize or lora is passed, an ImportError is raised with the message: 'PEFT is not available - install "pipeline" extra to enable'.
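Optional dependencies like this are commonly handled with a guard that fails fast only when the optional feature is requested. The following is a generic sketch of that pattern, not txtai's code: the `PEFT_AVAILABLE` flag and `require_peft` helper are hypothetical, and only the error message text comes from this page. txtai's own check may be structured differently.

```python
import importlib.util
from typing import Any

# Detect the optional dependency once, without importing it
PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None


def require_peft(quantize: Any, lora: Any, available: bool = PEFT_AVAILABLE) -> None:
    """Hypothetical guard: fail fast when quantization or LoRA needs peft."""
    if (quantize or lora) and not available:
        raise ImportError('PEFT is not available - install "pipeline" extra to enable')


# Plain fine-tuning (no quantize, no lora) never needs peft
require_peft(None, None, available=False)

try:
    require_peft(True, None, available=False)
except ImportError as error:
    print(error)  # PEFT is not available - install "pipeline" extra to enable
```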

See Also

Uses Heuristic
