Implementation: Hugging Face PEFT Bi-Encoder Cosine Loss
Metadata
- Source: examples/feature_extraction/peft_lora_embedding_semantic_search.py:L169-454
- Repository: huggingface/peft
- Type: Pattern Doc
- Domains: NLP, Semantic_Search
Overview
This implementation documents the bi-encoder training pattern for semantic search using cosine similarity loss, as demonstrated in the PEFT feature extraction example. The pattern comprises four key components: (1) a custom AutoModelForSentenceEmbedding wrapper that performs mean pooling and L2 normalization, (2) a cosine similarity computation function, (3) a cosine margin loss function, and (4) LoRA adapter configuration with TaskType.FEATURE_EXTRACTION.
Imports
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, default_data_collator, get_scheduler
from datasets import load_dataset, DatasetDict
from peft import LoraConfig, TaskType, get_peft_model
from accelerate import Accelerator
import evaluate
Core Components
AutoModelForSentenceEmbedding
A wrapper module that transforms a pretrained language model into a sentence embedding model using mean pooling and optional L2 normalization.
class AutoModelForSentenceEmbedding(nn.Module):
    def __init__(self, model_name, tokenizer, normalize=True):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs["attention_mask"])
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
Key design points:
- `model_output[0]` extracts the last hidden state (all token embeddings)
- `mean_pooling` computes an attention-mask-weighted mean, so padding tokens contribute zero weight
- `torch.clamp(..., min=1e-9)` prevents division by zero for fully-padded sequences
- L2 normalization (`torch.nn.functional.normalize`) projects embeddings onto the unit sphere
- The class delegates unknown attribute access to the wrapped model via `__getattr__`, enabling PEFT compatibility (see the sketch below)
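The `__getattr__` delegation mentioned in the last point is not shown in the snippet above. A minimal sketch of how such forwarding typically looks inside AutoModelForSentenceEmbedding (the method body here is an assumption, not copied from the example):

    def __getattr__(self, name: str):
        """Forward attribute lookups that nn.Module cannot resolve to the wrapped model."""
        try:
            # Defer to nn.Module's own handling first (parameters, buffers, submodules)
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the underlying transformer, e.g. for `config` or `device`
            return getattr(self.model, name)

This lets PEFT inspect the wrapped transformer's attributes as if the wrapper were the model itself.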
Cosine Similarity Computation
def get_cosing_embeddings(query_embs, product_embs):
    return torch.sum(query_embs * product_embs, axis=1)
Since both embeddings are L2-normalized, element-wise multiplication followed by summation along the feature axis yields the cosine similarity. The result is a 1D tensor of per-pair similarity scores. (The `get_cosing_embeddings` name, typo included, is carried over verbatim from the source script.)
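As an illustration (not part of the example script), the elementwise-product-and-sum matches `torch.nn.functional.cosine_similarity` whenever the inputs are already unit-normalized:

import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(4, 384), p=2, dim=1)  # 4 query embeddings, unit norm (384 is arbitrary)
p = F.normalize(torch.randn(4, 384), p=2, dim=1)  # 4 product embeddings, unit norm

scores = torch.sum(q * p, axis=1)                 # same computation as get_cosing_embeddings
assert torch.allclose(scores, F.cosine_similarity(q, p, dim=1), atol=1e-6)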
Cosine Margin Loss
def get_loss(cosine_score, labels):
    return torch.mean(
        torch.square(labels * (1 - cosine_score))
        + torch.clamp((1 - labels) * cosine_score, min=0.0)
    )
Loss behavior (a small worked example follows the list):
- Positive pairs (`labels=1`): penalizes `(1 - cosine_score)^2`, pushing similarity toward 1.0
- Negative pairs (`labels=0`): penalizes `max(cosine_score, 0)`, pushing similarity toward 0 or below. The `clamp(min=0.0)` ensures that negative similarities (already well-separated) incur no loss.
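A toy illustration of the two branches (values are illustrative, not from the script; uses `get_loss` as defined above):

import torch

cosine_score = torch.tensor([0.9, 0.2, -0.3])  # similarities for three query/product pairs
labels = torch.tensor([1.0, 0.0, 0.0])         # first pair relevant, the other two irrelevant

# Per-pair terms: (1 - 0.9)^2 = 0.01, max(0.2, 0) = 0.2, max(-0.3, 0) = 0.0
loss = get_loss(cosine_score, labels)          # mean of [0.01, 0.2, 0.0] = 0.07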
LoRA Configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=["key", "query", "value"],
)
model = AutoModelForSentenceEmbedding(model_name, tokenizer)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
Key parameters:
- `task_type=TaskType.FEATURE_EXTRACTION`: indicates embedding extraction mode; no classification or generation head is added
- `target_modules=["key", "query", "value"]`: injects LoRA adapters into the self-attention key, query, and value projections
- `r=8`: a low rank is sufficient for adapting embedding representations
- `lora_alpha=16`: scaling factor (effective scaling = alpha/r = 2.0); see the sketch after this list
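For intuition on the `alpha/r` scaling, here is a conceptual sketch of the standard LoRA update (not PEFT's internal implementation; shapes and names are illustrative):

import torch

def lora_linear(x, W, A, B, r=8, alpha=16):
    """Frozen base projection plus a low-rank update scaled by alpha / r."""
    scaling = alpha / r                       # 16 / 8 = 2.0 in this configuration
    return x @ W.T + scaling * (x @ A.T @ B.T)

# W is frozen (d_out x d_in); only A (r x d_in) and B (d_out x r) are trained
x = torch.randn(2, 768)
W = torch.randn(64, 768)
A = torch.randn(8, 768)
B = torch.zeros(64, 8)                        # B starts at zero, so the update is initially a no-op
print(lora_linear(x, W, A, B).shape)          # torch.Size([2, 64])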
Dataset Preparation
dataset = load_dataset(args.dataset_name, revision="main")
# or for sanity testing:
# train_dataset = load_dataset("smangrul/amazon_esci", split="train[:1024]")

def preprocess_function(examples):
    # Tokenize queries
    queries = examples["query"]
    result = tokenizer(queries, padding="max_length", max_length=70, truncation=True)
    result = {f"query_{k}": v for k, v in result.items()}

    # Tokenize products (documents)
    products = examples["product_title"]
    result_products = tokenizer(products, padding="max_length", max_length=70, truncation=True)
    for k, v in result_products.items():
        result[f"product_{k}"] = v

    # Relevance labels
    result["labels"] = examples["relevance_label"]
    return result

processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Running tokenizer on dataset",
)
Key details:
- Queries and products are tokenized separately with prefixed keys (`query_input_ids`, `product_input_ids`, etc.)
- Both are padded/truncated to `max_length=70`
- `relevance_label` provides the supervision signal (binary or graded relevance)
- The dataset used is smangrul/amazon_esci (Amazon product search relevance); a dataloader sketch follows this list
Training Loop
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

for epoch in range(args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(model):
            # Encode queries and products separately
            query_embs = model(
                **{k.replace("query_", ""): v for k, v in batch.items() if "query" in k}
            )
            product_embs = model(
                **{k.replace("product_", ""): v for k, v in batch.items() if "product" in k}
            )
            # Compute loss
            loss = get_loss(
                get_cosing_embeddings(query_embs, product_embs),
                batch["labels"],
            )
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            model.zero_grad()
Key design points:
- Query and product embeddings are computed with separate forward passes through the same model
- Batch key prefixes (`query_`, `product_`) are stripped before passing to the model
- Gradient accumulation is handled by `accelerator.accumulate(model)`
- The training loop uses the Accelerate library for distributed training support; a sketch of the Accelerator setup follows this list
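The `accelerator` object is assumed to have been created earlier in the script; a minimal sketch of that setup (the exact flags used in the example may differ):

from accelerate import Accelerator

# gradient_accumulation_steps makes accelerator.accumulate(model) defer the optimizer
# step until the configured number of micro-batches has been processed
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)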
Evaluation
metric = evaluate.load("roc_auc")
model.eval()
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        query_embs = model(**{k.replace("query_", ""): v for k, v in batch.items() if "query" in k})
        product_embs = model(**{k.replace("product_", ""): v for k, v in batch.items() if "product" in k})
        prediction_scores = get_cosing_embeddings(query_embs, product_embs)
    prediction_scores, references = accelerator.gather_for_metrics((prediction_scores, batch["labels"]))
    metric.add_batch(prediction_scores=prediction_scores, references=references)

eval_metric = metric.compute()
# Reports ROC-AUC score
The evaluation uses ROC-AUC (Area Under the Receiver Operating Characteristic Curve) as the primary metric, measuring the model's ability to rank relevant items above irrelevant ones.
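For reference, the metric can be exercised standalone with toy scores (illustrative values, not results from the example):

import evaluate

roc_auc = evaluate.load("roc_auc")
result = roc_auc.compute(
    prediction_scores=[0.9, 0.8, 0.3, 0.1],  # cosine similarities
    references=[1, 1, 0, 0],                 # relevance labels
)
print(result["roc_auc"])                     # 1.0: every positive is ranked above every negative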
Checkpoint Management
# Custom save/load hooks for PEFT compatibility with Accelerate
def save_model_hook(models, weights, output_dir):
    for model in models:
        # Save only the PEFT adapter weights into the checkpoint directory
        model.save_pretrained(output_dir)
        # Pop the weights so Accelerate does not also serialize the full model state
        weights.pop()

def load_model_hook(models, input_dir):
    while len(models) > 0:
        # Pop models so Accelerate does not try to load them again
        model = models.pop()
        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
            model.load_adapter(input_dir, model.active_adapter, is_trainable=True)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
Custom hooks are registered with the Accelerator to ensure PEFT adapter weights are correctly saved and loaded during checkpointing.
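Once the hooks are registered, checkpointing goes through Accelerate's standard state API. A usage sketch (the path layout and variable names are assumptions):

# Save a checkpoint (triggers save_model_hook, writing only the LoRA adapter weights)
accelerator.save_state(f"{args.output_dir}/step_{completed_steps}")

# Resume later (triggers load_model_hook, re-attaching the adapter in trainable mode)
accelerator.load_state(args.resume_from_checkpoint)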