
Implementation:CarperAI Trlx Metric Function Interface

From Leeroopedia


Knowledge Sources
Domains Evaluation, NLP
Last Updated 2026-02-07 16:00 GMT

Overview

Interface specification for user-defined evaluation metric functions in trlx training.

Description

The metric function is a user-defined callable that computes evaluation statistics on batches of generated text during periodic evaluation. This is a Pattern Doc — it documents the interface users must implement. The function receives generated samples, prompts, and outputs, and returns a dictionary mapping metric names to lists of per-sample scores. trlx calls this function in AccelerateRLTrainer.evaluate() at configurable intervals.
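During evaluation, the trainer reduces each per-sample score list to summary statistics for logging. A minimal sketch of that reduction (the exact aggregation and logging keys inside AccelerateRLTrainer are an assumption here; the metric values are hypothetical):

```python
from statistics import mean
from typing import Dict, List


def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
    """Reduce per-sample metric lists to scalar means, as a trainer
    might do before logging evaluation results."""
    return {name: mean(values) for name, values in metrics.items()}


# Hypothetical per-sample scores as returned by a metric_fn
metrics = {"sentiment": [0.9, 0.7, 0.8], "length": [40.0, 38.0, 42.0]}
print(reduce_metrics(metrics))
```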

Usage

Implement this interface when calling trlx.train() with any training method (PPO, ILQL, SFT) where you want to monitor generation quality. Pass it as the metric_fn argument. It is particularly important for ILQL and SFT, where no live reward function is available during evaluation.

Code Reference

Source Location

  • Repository: trlx
  • File: examples/ilql_sentiments.py
  • Lines: L31-33 (canonical ILQL example)

Interface Specification

from typing import Dict, List

def metric_fn(
    samples: List[str],       # Full generated text (prompt + output)
    prompts: List[str],       # Original prompt text
    outputs: List[str],       # Generated output only
    **kwargs                  # Additional metadata
) -> Dict[str, List[float]]:
    """
    Compute evaluation metrics on generated text.

    Args:
        samples: Complete generated strings.
        prompts: The original prompt strings.
        outputs: The generated completion strings only.

    Returns:
        Dict mapping metric names to lists of per-sample float scores.
        Example: {"sentiment": [0.95, 0.87, ...], "length": [42.0, 38.0, ...]}
    """
    ...
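A minimal compliant implementation of this interface, scoring each sample by output length in words (the metric name "output_length" is illustrative, not prescribed by trlx):

```python
from typing import Dict, List


def length_metric_fn(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    **kwargs,
) -> Dict[str, List[float]]:
    """Minimal metric_fn: per-sample word count of the generated output."""
    return {"output_length": [float(len(o.split())) for o in outputs]}
```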

Import

# No import needed — this is a user-defined function
import trlx
trlx.train(metric_fn=my_metric_fn, samples=samples, rewards=rewards, config=config)

I/O Contract

Inputs

Name Type Required Description
samples List[str] Yes Full generated text strings
prompts List[str] Yes Original prompt strings
outputs List[str] Yes Generated completion strings only
**kwargs Dict No Extra metadata

Outputs

Name Type Description
return Dict[str, List[float]] Mapping from metric name to list of per-sample scores
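The contract above can be checked before training with a small validator; a sketch (the helper name and checks are illustrative, not part of trlx):

```python
from typing import Dict, List


def validate_metric_output(metrics: Dict[str, List[float]], n_samples: int) -> None:
    """Assert that a metric_fn return value satisfies the I/O contract:
    a dict of str -> list of numeric scores, one score per sample."""
    assert isinstance(metrics, dict), "metric_fn must return a dict"
    for name, values in metrics.items():
        assert isinstance(name, str), "metric names must be strings"
        assert len(values) == n_samples, f"{name}: expected {n_samples} scores"
        assert all(isinstance(v, (int, float)) for v in values), (
            f"{name}: scores must be numeric"
        )
```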

Usage Examples

Sentiment Metric for ILQL

from transformers import pipeline

sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)

def metric_fn(samples, **kwargs):
    """Compute sentiment scores for evaluation."""
    output = sentiment_fn(samples, batch_size=16)
    sentiments = [o["score"] if o["label"] == "POSITIVE" else 1 - o["score"] for o in output]
    return {"sentiments": sentiments}

import trlx
from trlx.data.default_configs import default_ilql_config

config = default_ilql_config()
trainer = trlx.train(
    samples=samples,
    rewards=rewards,
    metric_fn=metric_fn,
    eval_prompts=eval_prompts,
    config=config,
)

Multi-Metric for SFT

def metric_fn(samples, prompts, outputs, **kwargs):
    """Compute multiple metrics for SFT evaluation."""
    lengths = [len(o.split()) for o in outputs]
    # compute_sentiment is a user-supplied helper returning one float
    # per sample, e.g. wrapping the sentiment pipeline shown above
    sentiments = compute_sentiment(samples)
    return {
        "output_length": lengths,
        "sentiment": sentiments,
    }
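A metric function can be sanity-checked directly on toy data before handing it to trlx.train(). A self-contained sketch with two stand-in metrics (word count and uppercase ratio — both hypothetical, chosen only to avoid external dependencies):

```python
from typing import Dict, List


def toy_metric_fn(samples: List[str], prompts: List[str],
                  outputs: List[str], **kwargs) -> Dict[str, List[float]]:
    """Hypothetical two-metric function for a standalone sanity check."""
    lengths = [float(len(o.split())) for o in outputs]
    upper = [sum(c.isupper() for c in o) / max(len(o), 1) for o in outputs]
    return {"output_length": lengths, "upper_ratio": upper}


result = toy_metric_fn(
    samples=["Q: hi A: Hello there"],
    prompts=["Q: hi"],
    outputs=["Hello there"],
)
print(result)  # one score per sample under each metric name
```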

Related Pages

Implements Principle

Requires Environment
