Implementation: CarperAI trlx Metric Function Interface
| Knowledge Sources | |
|---|---|
| Domains | Evaluation, NLP |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Interface specification for user-defined evaluation metric functions in trlx training.
Description
The metric function is a user-defined callable that computes evaluation statistics on batches of generated text during periodic evaluation. This is a Pattern Doc: it documents an interface that users must implement themselves. The function receives the full generated samples, the original prompts, and the generated outputs, and returns a dictionary mapping metric names to lists of per-sample scores. trlx calls this function from AccelerateRLTrainer.evaluate() at configurable intervals.
Usage
Implement this interface when calling trlx.train() with any training method (PPO, ILQL, SFT) where you want to monitor generation quality, and pass the function as the metric_fn argument. It is particularly important for ILQL and SFT, where no live reward function is available during training.
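A minimal sketch of an implementation, assuming only the interface described above (the metric name "output_length" and the function name are illustrative, not part of trlx):

```python
from typing import Dict, List

def length_metric_fn(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    **kwargs,
) -> Dict[str, List[float]]:
    """Return one score per generated sample: the output word count."""
    return {"output_length": [float(len(o.split())) for o in outputs]}

# Exercised standalone, outside of trlx, to show the contract:
scores = length_metric_fn(
    samples=["Q: hi A: hello there"],
    prompts=["Q: hi"],
    outputs=["A: hello there"],
)
print(scores)  # {'output_length': [3.0]}
```

The same callable would be passed as `trlx.train(..., metric_fn=length_metric_fn)`.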
Code Reference
Source Location
- Repository: trlx
- File: examples/ilql_sentiments.py
- Lines: L31-33 (canonical ILQL example)
Interface Specification
from typing import Dict, List

def metric_fn(
samples: List[str], # Full generated text (prompt + output)
prompts: List[str], # Original prompt text
outputs: List[str], # Generated output only
**kwargs # Additional metadata
) -> Dict[str, List[float]]:
"""
Compute evaluation metrics on generated text.
Args:
samples: Complete generated strings.
prompts: The original prompt strings.
outputs: The generated completion strings only.
Returns:
Dict mapping metric names to lists of per-sample float scores.
        Example: {"sentiment": [0.95, 0.87, ...], "length": [42.0, 38.0, ...]}
"""
...
Import
# No import needed — this is a user-defined function
import trlx
trlx.train(metric_fn=my_metric_fn, samples=samples, rewards=rewards, config=config)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| samples | List[str] | Yes | Full generated text strings |
| prompts | List[str] | Yes | Original prompt strings |
| outputs | List[str] | Yes | Generated completion strings only |
| **kwargs | Dict | No | Extra metadata |
Outputs
| Name | Type | Description |
|---|---|---|
| return | Dict[str, List[float]] | Mapping from metric name to list of per-sample scores |
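The contract can be validated mechanically; the helper below is a hypothetical sketch, not part of trlx, that checks the one-score-per-sample shape of a metric function's return value:

```python
from typing import Dict, List

def check_metric_output(metrics: Dict[str, List[float]], n_samples: int) -> None:
    """Validate that each metric maps to one score per generated sample."""
    for name, scores in metrics.items():
        assert isinstance(name, str), "metric names must be strings"
        assert len(scores) == n_samples, (
            f"metric {name!r} has {len(scores)} scores for {n_samples} samples"
        )

# Two samples, two scores per metric: passes silently.
check_metric_output({"sentiment": [0.95, 0.87], "length": [42.0, 38.0]}, n_samples=2)
print("contract ok")
```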
Usage Examples
Sentiment Metric for ILQL
from transformers import pipeline
sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)
def metric_fn(samples, **kwargs):
"""Compute sentiment scores for evaluation."""
output = sentiment_fn(samples, batch_size=16)
sentiments = [o["score"] if o["label"] == "POSITIVE" else 1 - o["score"] for o in output]
return {"sentiments": sentiments}
import trlx
from trlx.data.default_configs import default_ilql_config
config = default_ilql_config()
trainer = trlx.train(
samples=samples,
rewards=rewards,
metric_fn=metric_fn,
eval_prompts=eval_prompts,
config=config,
)
Multi-Metric for SFT
def metric_fn(samples, prompts, outputs, **kwargs):
"""Compute multiple metrics for SFT evaluation."""
lengths = [len(o.split()) for o in outputs]
    sentiments = compute_sentiment(samples)  # user-supplied sentiment scorer
return {
"output_length": lengths,
"sentiment": sentiments,
}
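In the example above, compute_sentiment is a user-supplied helper. To exercise the metric function standalone, a trivial keyword-matching stub (an assumption for illustration only) can stand in for it:

```python
def compute_sentiment(samples):
    """Stub scorer: 1.0 if the text mentions 'good', else 0.0."""
    return [1.0 if "good" in s.lower() else 0.0 for s in samples]

def metric_fn(samples, prompts, outputs, **kwargs):
    """Compute multiple metrics, one score per sample for each name."""
    lengths = [float(len(o.split())) for o in outputs]
    sentiments = compute_sentiment(samples)
    return {"output_length": lengths, "sentiment": sentiments}

metrics = metric_fn(
    samples=["The movie was good.", "Terrible plot."],
    prompts=["The movie", "Terrible"],
    outputs=["was good.", "plot."],
)
print(metrics)  # {'output_length': [2.0, 1.0], 'sentiment': [1.0, 0.0]}
```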