
Implementation:Bigscience workshop Petals Distributed Evaluation Loop

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Evaluation, NLP
Last Updated 2026-02-09 14:00 GMT

Overview

A user-defined evaluation pattern for measuring classification performance of prompt-tuned distributed Petals models using standard PyTorch evaluation conventions.

Description

This is a Pattern Doc — it documents the interface and conventions that users must follow for evaluating distributed models, rather than a specific library API.

Key components:

  • torch.no_grad(): Disables gradient computation for memory efficiency and speed
  • model(**batch): Forward pass through the distributed model (uses RemoteSequential.forward without autograd)
  • torch.argmax(): Converts logits to class predictions
  • Metric computation: User-defined metrics (accuracy, F1, precision, recall)

The forward pass during evaluation routes through the same remote servers as training, but without the _RemoteSequentialAutogradFunction — instead, it uses the simpler forward path through RemoteSequential.
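The effect of torch.no_grad() can be illustrated with an ordinary local module; the same principle applies when the forward pass routes through remote servers. This is a minimal sketch using a plain linear layer as a stand-in, not Petals-specific code:

```python
import torch

# Under no_grad, no autograd graph is recorded for the forward pass,
# which mirrors how evaluation skips _RemoteSequentialAutogradFunction
# and takes the simpler RemoteSequential.forward path.
layer = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

train_out = layer(x)       # autograd graph recorded; gradients possible
with torch.no_grad():
    eval_out = layer(x)    # no graph recorded; less memory, faster

print(train_out.requires_grad)  # True
print(eval_out.requires_grad)   # False
```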

Usage

Implement this pattern after each training epoch or at the end of training. The pattern works with any classification task using DistributedLlamaForSequenceClassification or similar models.
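The placement of the evaluation call relative to the training loop can be sketched as follows. A tiny local classifier stands in for the distributed model here (an assumption for the sake of a runnable demo); with a real DistributedLlamaForSequenceClassification the structure is unchanged:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the distributed model (assumption: real usage would load
# DistributedLlamaForSequenceClassification instead).
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(4, 2))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

data = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
train_dl = DataLoader(data, batch_size=8)
eval_dl = DataLoader(data, batch_size=8)

def evaluate(model, eval_dataloader):
    model.eval()
    correct = total = 0
    with torch.no_grad():                      # no autograd during eval
        for x, y in eval_dataloader:
            preds = torch.argmax(model(x), dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
    return {"accuracy": correct / total}

for epoch in range(2):
    model.train()
    for x, y in train_dl:                      # one training epoch
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()
    metrics = evaluate(model, eval_dl)         # evaluate after each epoch
```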

Code Reference

Source Location

  • Repository: petals
  • File: Pattern based on src/petals/client/remote_sequential.py (L52-58, RemoteSequential.forward)
  • File: src/petals/models/llama/model.py (L156-174, DistributedLlamaForSequenceClassification)

Interface Specification

def evaluate(
    model,                    # DistributedLlamaForSequenceClassification
    eval_dataloader,          # DataLoader yielding batches with input_ids, attention_mask, labels
    device: str = "cpu",      # Device for local operations
) -> Dict[str, float]:
    """
    Evaluate a prompt-tuned distributed model.

    Returns:
        Dict with metric names and values (e.g. {"accuracy": 0.92, "f1": 0.91})
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in eval_dataloader:
            # Move local tensors to the target device before the forward pass
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch["labels"].cpu().tolist())

    accuracy = sum(p == t for p, t in zip(all_preds, all_labels)) / len(all_labels)
    return {"accuracy": accuracy}

Import

import torch
from typing import Dict
# No Petals-specific import is needed — this is a user-defined pattern

I/O Contract

Inputs

Name Type Required Description
model DistributedLlamaForSequenceClassification Yes Prompt-tuned distributed model
eval_dataloader DataLoader Yes Validation/test data batches with input_ids, attention_mask, labels

Outputs

Name Type Description
metrics Dict[str, float] Evaluation metrics (accuracy, F1, loss, etc.)
predictions List[int] Predicted class labels for all examples (optional — the reference implementation above returns only metrics)

Usage Examples

Full Evaluation After Training

import torch
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, eval_dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            outputs = model(**batch)
            total_loss += outputs.loss.item()

            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch["labels"].cpu().tolist())

    avg_loss = total_loss / len(eval_dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="weighted")

    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "f1": f1,
    }

# After training
metrics = evaluate(model, eval_dataloader)
print(f"Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")
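When weighted F1 hides per-class behavior, finer-grained metrics can be derived from the same prediction and label lists that evaluate() collects. This is a minimal sketch with illustrative stand-in lists for all_preds and all_labels:

```python
import torch

# Illustrative stand-ins for the lists collected during evaluation
all_preds  = [0, 1, 1, 0, 1, 0]
all_labels = [0, 1, 0, 0, 1, 1]

n_classes = 2
confusion = torch.zeros(n_classes, n_classes, dtype=torch.long)
for p, t in zip(all_preds, all_labels):
    confusion[t, p] += 1  # rows: true class, columns: predicted class

# Per-class recall: correct predictions over all examples of that class
per_class_recall = confusion.diag().float() / confusion.sum(dim=1).float()
print(confusion.tolist())        # [[2, 1], [1, 2]]
print(per_class_recall.tolist())
```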

Related Pages

Implements Principle

Requires Environment
