Implementation: BigScience Workshop Petals Distributed Evaluation Loop
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Evaluation, NLP |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
A user-defined evaluation pattern for measuring classification performance of prompt-tuned distributed Petals models using standard PyTorch evaluation conventions.
Description
This is a Pattern Doc — it documents the interface and conventions that users must follow for evaluating distributed models, rather than a specific library API.
Key components:
- torch.no_grad(): Disables gradient computation for memory efficiency and speed
- model(**batch): Forward pass through the distributed model (uses RemoteSequential.forward without autograd)
- torch.argmax(): Converts logits to class predictions
- Metric computation: User-defined metrics (accuracy, F1, precision, recall)
The forward pass during evaluation routes through the same remote servers as training, but without the _RemoteSequentialAutogradFunction — instead, it uses the simpler forward path through RemoteSequential.
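As a rough illustration of the two paths (a sketch only; model and batch are assumed to be an already-loaded distributed model and a tokenized batch that includes labels):
import torch

# Training: gradients are enabled, so the remote forward is recorded on the
# autograd tape (Petals routes this through _RemoteSequentialAutogradFunction).
outputs = model(**batch)
outputs.loss.backward()

# Evaluation: no_grad() disables the tape, so the same call takes the plain
# RemoteSequential.forward path -- no activations are kept for a backward pass.
with torch.no_grad():
    outputs = model(**batch)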
Usage
Run this pattern after each training epoch or once at the end of training (see the sketch below). It works with any classification task using DistributedLlamaForSequenceClassification or a similar sequence-classification model.
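A minimal sketch of where the call sits in a training loop (train_one_epoch, optimizer, num_epochs, and the dataloaders are user-defined placeholders; evaluate is the function specified below):
for epoch in range(num_epochs):
    model.train()  # re-enable training mode, since evaluate() calls model.eval()
    train_one_epoch(model, train_dataloader, optimizer)  # user-defined training step
    metrics = evaluate(model, eval_dataloader)           # this pattern
    print(f"epoch {epoch + 1}: {metrics}")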
Code Reference
Source Location
- Repository: petals
- File: Pattern based on src/petals/client/remote_sequential.py (L52-58, RemoteSequential.forward)
- File: src/petals/models/llama/model.py (L156-174, DistributedLlamaForSequenceClassification)
Interface Specification
def evaluate(
    model,                # DistributedLlamaForSequenceClassification
    eval_dataloader,      # DataLoader yielding batches with input_ids, attention_mask, labels
    device: str = "cpu",  # Device for local operations
) -> Dict[str, float]:
    """
    Evaluate a prompt-tuned distributed model.

    Returns:
        Dict with metric names and values (e.g. {"accuracy": 0.92, "f1": 0.91})
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}  # local tensors only; remote blocks are unaffected
            outputs = model(**batch)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)  # logits -> predicted class indices
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch["labels"].cpu().tolist())
    accuracy = sum(p == t for p, t in zip(all_preds, all_labels)) / len(all_labels)
    return {"accuracy": accuracy}
Import
from typing import Dict

import torch
# No Petals-specific import is required; this is a user-defined pattern.
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | DistributedLlamaForSequenceClassification | Yes | Prompt-tuned distributed model |
| eval_dataloader | DataLoader | Yes | Validation/test data batches with input_ids, attention_mask, labels |
| device | str | No | Device for local tensor operations (default "cpu") |
Outputs
| Name | Type | Description |
|---|---|---|
| metrics | Dict[str, float] | Evaluation metrics (accuracy, F1, loss, etc.) |
| predictions | List[int] | Predicted class labels for all examples (returned only by variants that expose them; see below) |
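The core interface returns only the metrics dict. When the raw predictions are needed as well (the predictions row above), a variant can return them as a second value. A minimal sketch, assuming torch is imported as in the Import section (evaluate_with_predictions is a hypothetical name):
def evaluate_with_predictions(model, eval_dataloader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in eval_dataloader:
            preds = torch.argmax(model(**batch).logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch["labels"].cpu().tolist())
    accuracy = sum(p == t for p, t in zip(all_preds, all_labels)) / len(all_labels)
    return {"accuracy": accuracy}, all_preds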
Usage Examples
Full Evaluation After Training
import torch
from sklearn.metrics import accuracy_score, f1_score
def evaluate(model, eval_dataloader):
    model.eval()  # disable dropout and other train-mode behavior in local modules
    all_preds = []
    all_labels = []
    total_loss = 0.0
    with torch.no_grad():
        for batch in eval_dataloader:
            outputs = model(**batch)  # loss is computed because "labels" is in the batch
            total_loss += outputs.loss.item()
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch["labels"].cpu().tolist())
    avg_loss = total_loss / len(eval_dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="weighted")  # weighted F1 is robust to class imbalance
    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "f1": f1,
    }
# After training
metrics = evaluate(model, eval_dataloader)
print(f"Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")
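If per-class diagnostics are useful, the lists collected above plug directly into sklearn's reporting utilities, e.g. by adding inside evaluate() just before the return:
from sklearn.metrics import classification_report

print(classification_report(all_labels, all_preds))  # per-class precision, recall, F1, and support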