Implementation:Huggingface Peft Bi Encoder Cosine Loss

From Leeroopedia


Overview

This implementation documents the bi-encoder training pattern for semantic search using cosine similarity loss, as demonstrated in the PEFT feature extraction example. The pattern comprises four key components: (1) a custom AutoModelForSentenceEmbedding wrapper that performs mean pooling and L2 normalization, (2) a cosine similarity computation function, (3) a cosine margin loss function, and (4) LoRA adapter configuration with TaskType.FEATURE_EXTRACTION.

Imports

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, default_data_collator, get_scheduler
from datasets import load_dataset, DatasetDict
from peft import LoraConfig, TaskType, get_peft_model
from accelerate import Accelerator
import evaluate

Core Components

AutoModelForSentenceEmbedding

A wrapper module that transforms a pretrained language model into a sentence embedding model using mean pooling and optional L2 normalization.

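class AutoModelForSentenceEmbedding(nn.Module):
    def __init__(self, model_name, tokenizer, normalize=True):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)  # pretrained encoder backbone
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        # Average token embeddings, ignoring padding positions
        embeddings = self.mean_pooling(model_output, kwargs["attention_mask"])
        if self.normalize:
            # Project onto the unit sphere so a dot product equals cosine similarity
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # last hidden state: (batch, seq_len, hidden)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def __getattr__(self, name: str):
        """Forward missing attributes (e.g. config) to the wrapped model."""
        try:
            return super().__getattr__(name)  # nn.Module lookup: parameters, buffers, submodules
        except AttributeError:
            return getattr(self.model, name)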

Key design points:

  • model_output[0] extracts the last hidden state (all token embeddings)
  • mean_pooling computes an attention-mask-weighted mean, ensuring padding tokens contribute zero weight
  • torch.clamp(..., min=1e-9) prevents division by zero for fully-padded sequences
  • L2 normalization (torch.nn.functional.normalize) projects embeddings onto the unit sphere
  • The class delegates unknown attribute access to the wrapped model via __getattr__, enabling PEFT compatibility
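The pooling arithmetic can be checked on a toy tensor without loading a model. This is a minimal sketch: `mean_pooling` here is a standalone copy that takes the last hidden state directly (an assumption for illustration), and the tensors are made-up values chosen so the averages are easy to verify by hand. It shows that a padded position with a large value does not leak into the mean, and that a fully padded row yields zeros rather than NaN thanks to the clamp.

```python
import torch
import torch.nn.functional as F


def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as AutoModelForSentenceEmbedding.mean_pooling, but taking the
    # last hidden state directly instead of a model output object.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # avoid 0 / 0
    return summed / counts


# Toy batch: 2 sequences, 3 token positions, hidden size 2.
token_embeddings = torch.tensor(
    [
        [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
        [[2.0, 0.0], [0.0, 2.0], [4.0, 4.0]],       # no padding
    ]
)
attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

pooled = mean_pooling(token_embeddings, attention_mask)
print(pooled.tolist())  # [[2.0, 3.0], [2.0, 2.0]]

# L2 normalization: every row ends up with unit norm.
normalized = F.normalize(pooled, p=2, dim=1)

# A fully padded sequence pools to zeros instead of NaN.
empty = mean_pooling(torch.ones(1, 2, 2), torch.zeros(1, 2, dtype=torch.long))
print(empty.tolist())  # [[0.0, 0.0]]
```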

Cosine Similarity Computation

def get_cosing_embeddings(query_embs, product_embs):
    # Row-wise dot product; equals cosine similarity because both inputs are unit-normalized
    return torch.sum(query_embs * product_embs, dim=1)

Since both embeddings are L2-normalized, element-wise multiplication followed by summation along the feature axis yields the cosine similarity. The result is a 1D tensor of per-pair similarity scores.
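As a quick sanity check, the row-wise dot product can be compared against PyTorch's built-in `torch.nn.functional.cosine_similarity` on random unit vectors. The random inputs are illustrative only. The last line shows why normalization matters: scaling one side by 3 scales the dot product by 3, so the shortcut is only a cosine similarity when both sides have unit norm.

```python
import torch
import torch.nn.functional as F


def get_cosing_embeddings(query_embs, product_embs):
    # Row-wise dot product; equals cosine similarity only for unit-norm rows.
    return torch.sum(query_embs * product_embs, dim=1)


torch.manual_seed(0)
query_embs = F.normalize(torch.randn(4, 8), p=2, dim=1)
product_embs = F.normalize(torch.randn(4, 8), p=2, dim=1)

dot_scores = get_cosing_embeddings(query_embs, product_embs)
reference = F.cosine_similarity(query_embs, product_embs, dim=1)
print(torch.allclose(dot_scores, reference, atol=1e-6))  # True
print(dot_scores.shape)  # torch.Size([4])

# Without normalization the dot product is no longer bounded by [-1, 1].
scaled = get_cosing_embeddings(3 * query_embs, product_embs)
```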

Cosine Margin Loss

def get_loss(cosine_score, labels):
    return torch.mean(
        torch.square(labels * (1 - cosine_score))
        + torch.clamp((1 - labels) * cosine_score, min=0.0)
    )

Loss behavior:

  • Positive pairs (labels=1): Penalizes (1 - cosine_score)^2, pushing similarity toward 1.0
  • Negative pairs (labels=0): Penalizes max(cosine_score, 0), pushing similarity toward 0 or below. The clamp(min=0.0) ensures that negative similarities (already well-separated) incur no loss.
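A small worked example makes the two branches concrete. The three scores below are made-up values: one positive pair at 0.8, one negative pair at 0.3, and one negative pair already at -0.5. The function body is copied from above; only the inputs are hypothetical.

```python
import torch


def get_loss(cosine_score, labels):
    # Squared gap to 1 for positives; hinge at 0 (not squared) for negatives.
    return torch.mean(
        torch.square(labels * (1 - cosine_score))
        + torch.clamp((1 - labels) * cosine_score, min=0.0)
    )


cosine_score = torch.tensor([0.8, 0.3, -0.5])
labels = torch.tensor([1.0, 0.0, 0.0])

# Per-pair terms: (1 - 0.8)^2 = 0.04, max(0.3, 0) = 0.3, max(-0.5, 0) = 0.0
# Mean over the batch: (0.04 + 0.3 + 0.0) / 3
loss = get_loss(cosine_score, labels)
print(round(loss.item(), 4))  # 0.1133
```

Only the mildly similar negative pair and the not-quite-aligned positive pair contribute; the negative pair at -0.5 is already on the correct side of zero and adds nothing.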

LoRA Configuration

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=["key", "query", "value"],
)

model = AutoModelForSentenceEmbedding(model_name, tokenizer)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Key parameters:

  • task_type=TaskType.FEATURE_EXTRACTION: Indicates embedding extraction mode; no classification or generation head is added
  • target_modules=["key", "query", "value"]: Injects LoRA adapters into the self-attention key, query, and value projections
  • r=8: Rank of the LoRA update matrices; a small rank keeps the number of trainable parameters low while still adapting the embedding space
  • lora_alpha=16: Scaling factor (effective scaling = alpha/r = 2.0)
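To illustrate what r and lora_alpha control, here is a hand-rolled LoRA layer. It is a hypothetical sketch of the standard LoRA formulation (frozen base weight plus a scaled low-rank update B·A), not PEFT's own layer class. The hidden size of 384 is an assumption picked for arithmetic, not taken from the example's model. It shows the scaling of 16 / 8 = 2.0, that the adapter is a no-op at initialization because B starts at zero, and how many parameters one adapted projection adds.

```python
import torch
import torch.nn as nn


class ToyLoraLinear(nn.Module):
    """Hypothetical, hand-rolled LoRA layer for illustration (not PEFT's class).

    Computes base(x) + (lora_alpha / r) * B(A(x)) with the base weights frozen.
    """

    def __init__(self, base: nn.Linear, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the adapter is trained
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # adapter output starts at zero
        self.scaling = lora_alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


torch.manual_seed(0)
hidden = 384  # assumed hidden size, for arithmetic only
base_query = nn.Linear(hidden, hidden)
layer = ToyLoraLinear(base_query, r=8, lora_alpha=16)

x = torch.randn(2, hidden)
print(layer.scaling)  # 2.0
print(torch.allclose(layer(x), base_query(x)))  # True at initialization

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 8 * 384 + 384 * 8 = 6144
```

With target_modules=["key", "query", "value"], an adapter like this sits on each of the three projections in every attention layer, so the trainable count grows linearly with r and with the number of layers.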

Dataset Preparation

dataset = load_dataset(args.dataset_name, revision="main")
# or for sanity testing:
# train_dataset = load_dataset("smangrul/amazon_esci", split="train[:1024]")

def preprocess_function(examples):
    # Tokenize queries
    queries = examples["query"]
    result = tokenizer(queries, padding="max_length", max_length=70, truncation=True)
    result = {f"query_{k}": v for k, v in result.items()}

    # Tokenize products (documents)
    products = examples["product_title"]
    result_products = tokenizer(products, padding="max_length", max_length=70, truncation=True)
    for k, v in result_products.items():
        result[f"product_{k}"] = v

    # Relevance labels
    result["labels"] = examples["relevance_label"]
    return result

processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Running tokenizer on dataset",
)

Key details:

  • Queries and products are tokenized separately with prefixed keys (query_input_ids, product_input_ids, etc.)
  • Both are padded/truncated to max_length=70
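  • relevance_label provides the supervision signal; the loss treats 1 as a relevant (positive) pair and 0 as an irrelevant (negative) pair, and the ROC-AUC evaluation likewise expects binary labels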
  • The dataset used is smangrul/amazon_esci (Amazon product search relevance)
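The resulting column layout can be seen without downloading a tokenizer. In this sketch, `toy_tokenizer` is a hypothetical stand-in that fakes token ids and mimics padding="max_length" and truncation; `preprocess_function` follows the logic above, and the two rows in `examples` are made-up inputs shaped like a batched map call.

```python
def toy_tokenizer(texts, padding="max_length", max_length=70, truncation=True):
    """Hypothetical stand-in for a Hugging Face tokenizer: fake ids per whitespace
    token, then truncate/pad every row to max_length."""
    input_ids, attention_mask = [], []
    for text in texts:
        ids = [sum(map(ord, tok)) for tok in text.split()]  # fake token ids
        if truncation:
            ids = ids[:max_length]
        pad = max_length - len(ids) if padding == "max_length" else 0
        input_ids.append(ids + [0] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


tokenizer = toy_tokenizer  # hypothetical; the real script uses a pretrained tokenizer


def preprocess_function(examples):
    # Queries -> query_input_ids, query_attention_mask
    result = tokenizer(examples["query"], padding="max_length", max_length=70, truncation=True)
    result = {f"query_{k}": v for k, v in result.items()}
    # Product titles -> product_input_ids, product_attention_mask
    products = tokenizer(examples["product_title"], padding="max_length", max_length=70, truncation=True)
    for k, v in products.items():
        result[f"product_{k}"] = v
    result["labels"] = examples["relevance_label"]
    return result


# Made-up batch, shaped like what a batched dataset.map passes in.
examples = {
    "query": ["usb c cable", "running shoes"],
    "product_title": ["Braided USB-C to USB-C Cable 2m", "Ceramic Coffee Mug"],
    "relevance_label": [1, 0],
}
out = preprocess_function(examples)
print(sorted(out))
# ['labels', 'product_attention_mask', 'product_input_ids', 'query_attention_mask', 'query_input_ids']
print(len(out["query_input_ids"][0]), sum(out["query_attention_mask"][0]))  # 70 3
```

Fixed-length rows are what let default_data_collator stack each column into a single tensor per batch.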

Training Loop

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

for epoch in range(args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(model):
            # Encode queries and products separately
            query_embs = model(
                **{k.replace("query_", ""): v for k, v in batch.items() if "query" in k}
            )
            product_embs = model(
                **{k.replace("product_", ""): v for k, v in batch.items() if "product" in k}
            )

            # Compute loss
            loss = get_loss(
                get_cosing_embeddings(query_embs, product_embs),
                batch["labels"],
            )

            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            model.zero_grad()

Key design points:

  • Query and product embeddings are computed with separate forward passes through the same model
  • Batch key prefixes (query_, product_) are stripped before passing to the model
  • Gradient accumulation is handled by accelerator.accumulate(model)
  • The training loop uses the Accelerate library for distributed training support
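One step of the loop body can be exercised on CPU without Accelerate or a pretrained model. This is a sketch under stated assumptions: `ToyEncoder` is a hypothetical stand-in for the PEFT-wrapped AutoModelForSentenceEmbedding, the batch is random, and `split_by_prefix` is a hypothetical helper that uses startswith rather than the substring test in the loop above (a stricter variant, not the original code). It shows the shared-weight, two-pass encoding and a single optimizer update.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_by_prefix(batch, prefix):
    # "query_input_ids" -> "input_ids"; keys without the prefix (e.g. "labels") are dropped.
    return {k[len(prefix):]: v for k, v in batch.items() if k.startswith(prefix)}


class ToyEncoder(nn.Module):
    """Hypothetical stand-in encoder: embed ids, mean-pool over the mask, L2-normalize."""

    def __init__(self, vocab_size=50, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)

    def forward(self, input_ids, attention_mask):
        tok = self.emb(input_ids)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (tok * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return F.normalize(pooled, p=2, dim=1)


def get_loss(cosine_score, labels):
    return torch.mean(
        torch.square(labels * (1 - cosine_score))
        + torch.clamp((1 - labels) * cosine_score, min=0.0)
    )


torch.manual_seed(0)
batch = {
    "query_input_ids": torch.randint(1, 50, (4, 6)),
    "query_attention_mask": torch.ones(4, 6, dtype=torch.long),
    "product_input_ids": torch.randint(1, 50, (4, 6)),
    "product_attention_mask": torch.ones(4, 6, dtype=torch.long),
    "labels": torch.tensor([1.0, 0.0, 1.0, 0.0]),
}

model = ToyEncoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
before = model.emb.weight.detach().clone()

model.train()
# Same encoder, two separate forward passes (shared weights).
query_embs = model(**split_by_prefix(batch, "query_"))
product_embs = model(**split_by_prefix(batch, "product_"))
loss = get_loss(torch.sum(query_embs * product_embs, dim=1), batch["labels"])
loss.backward()
optimizer.step()
model.zero_grad()
print(sorted(split_by_prefix(batch, "query_")))  # ['attention_mask', 'input_ids']
```

Because both passes go through one module, gradients from the query side and the product side accumulate into the same parameters before the optimizer step.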

Evaluation

metric = evaluate.load("roc_auc")

model.eval()
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        query_embs = model(**{k.replace("query_", ""): v for k, v in batch.items() if "query" in k})
        product_embs = model(**{k.replace("product_", ""): v for k, v in batch.items() if "product" in k})

    prediction_scores = get_cosing_embeddings(query_embs, product_embs)
    prediction_scores, references = accelerator.gather_for_metrics((prediction_scores, batch["labels"]))
    metric.add_batch(prediction_scores=prediction_scores, references=references)

eval_metric = metric.compute()
# Returns a dict of the form {"roc_auc": <score>}

The evaluation uses ROC-AUC (Area Under the Receiver Operating Characteristic Curve) as the primary metric, measuring the model's ability to rank relevant items above irrelevant ones.
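For binary labels, ROC-AUC equals the fraction of (relevant, irrelevant) pairs in which the relevant pair receives the higher score, with ties counted as half. The sketch below computes it that way directly; `pairwise_roc_auc` is a hypothetical helper with O(P·N) cost, meant to explain the number rather than replace the evaluate metric, and the scores are made-up.

```python
import torch


def pairwise_roc_auc(scores, labels):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    scores = torch.as_tensor(scores, dtype=torch.float)
    labels = torch.as_tensor(labels)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos.unsqueeze(1) - neg.unsqueeze(0)  # shape (num_pos, num_neg)
    wins = (diff > 0).float() + 0.5 * (diff == 0).float()
    return wins.mean().item()


scores = [0.9, 0.8, 0.3, 0.1]  # e.g. cosine scores from get_cosing_embeddings
labels = [1, 0, 1, 0]
# Pairs (pos, neg): (0.9, 0.8) win, (0.9, 0.1) win, (0.3, 0.8) loss, (0.3, 0.1) win
print(pairwise_roc_auc(scores, labels))  # 0.75
```

This view explains why ROC-AUC suits the bi-encoder: it depends only on the ordering of scores, so the cosine values need no calibration or threshold.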

Checkpoint Management

# Custom save/load hooks for PEFT compatibility with Accelerate
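def save_model_hook(models, weights, output_dir):
    for model in models:
        # Save only the PEFT adapter weights and config into the checkpoint directory
        model.save_pretrained(output_dir)
        # Pop the matching state dict so Accelerate does not also save the full model
        weights.pop()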

def load_model_hook(models, input_dir):
    while len(models) > 0:
        model = models.pop()
        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
            model.load_adapter(input_dir, model.active_adapter, is_trainable=True)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

Custom hooks are registered with the Accelerator to ensure PEFT adapter weights are correctly saved and loaded during checkpointing.
