Implementation:CarperAI Trlx Accelerate Base Datatypes
| Knowledge Sources | |
|---|---|
| Domains | Data_Structures, Reinforcement_Learning |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool providing dataclass definitions for prompt and RL data elements used by the Accelerate-based trainers.
Description
This module defines four simple dataclasses used across the Accelerate trainer stack: PromptElement (single prompt with text and tokens), PromptBatch (batched prompts), AccelerateRLElement (single RL output with tokens and rewards), and AccelerateRLBatchElement (batched RL outputs). All tensor fields are annotated with TensorType from torchtyping for shape documentation. These dataclasses serve as the standard data interchange format between pipelines and trainers.
Usage
Use these dataclasses when constructing or consuming prompt and RL data in the Accelerate-based training pipeline. They are used internally by PromptPipeline, AcceleratePPOTrainer, and other trainer classes.
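As a rough illustration of how single prompts might be gathered into a batch, the sketch below uses stand-in dataclasses with the same field names but plain Python lists instead of tensors, so it runs without torch, torchtyping, or trlx installed. The `collate_prompts` helper, the right-padding strategy, and the `pad_token_id` default are assumptions made for this example and are not taken from the trlx source.

```python
from dataclasses import dataclass
from typing import Iterable, List


# Stand-ins mirroring the documented field names; lists replace tensors
# so the sketch has no third-party dependencies.
@dataclass
class PromptElement:
    text: str
    tokens: List[int]


@dataclass
class PromptBatch:
    text: Iterable[str]
    tokens: List[List[int]]


def collate_prompts(elems: List[PromptElement], pad_token_id: int = 0) -> PromptBatch:
    """Hypothetical helper: right-pad every token list to the longest one."""
    max_len = max(len(e.tokens) for e in elems)
    padded = [e.tokens + [pad_token_id] * (max_len - len(e.tokens)) for e in elems]
    return PromptBatch(text=[e.text for e in elems], tokens=padded)


elems = [
    PromptElement(text="Hello world", tokens=[15496, 995]),
    PromptElement(text="How are you", tokens=[2437, 389, 345]),
]
batch = collate_prompts(elems)
print(batch.tokens)
```

With real tensors, the same idea would typically be expressed with a padding utility from the tensor library rather than list concatenation.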
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/data/accelerate_base_datatypes.py
- Lines: 1-68
Signature
@dataclass
class PromptElement:
    """Single prompt with text and tokenized representation."""
    text: str
    tokens: TensorType["num_tokens"]

@dataclass
class PromptBatch:
    """Batched prompts with text list and padded token tensor."""
    text: Iterable[str]
    tokens: TensorType["batch_size", "num_tokens"]

@dataclass
class AccelerateRLElement:
    """Single RL element with output tokens and per-token rewards."""
    output_tokens: TensorType["output_size"]
    rewards: TensorType["output_size"]

@dataclass
class AccelerateRLBatchElement:
    """Batched RL elements with output tokens and rewards tensors."""
    output_tokens: TensorType["batch_size", "output_size"]
    rewards: TensorType["batch_size", "output_size"]
Import
from trlx.data.accelerate_base_datatypes import (
    PromptElement,
    PromptBatch,
    AccelerateRLElement,
    AccelerateRLBatchElement,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| text | str / Iterable[str] | Yes | Raw prompt text(s) |
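| tokens | TensorType | Yes | Prompt token IDs: 1D ["num_tokens"] for PromptElement, 2D ["batch_size", "num_tokens"] for PromptBatch |
| output_tokens | TensorType | Yes | Generated output token IDs: 1D ["output_size"] for AccelerateRLElement, 2D ["batch_size", "output_size"] for AccelerateRLBatchElement |
| rewards | TensorType | Yes | Per-token reward values, annotated with the same shape as output_tokens |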
Outputs
| Name | Type | Description |
|---|---|---|
| PromptElement | dataclass | Single prompt container |
| PromptBatch | dataclass | Batched prompt container |
| AccelerateRLElement | dataclass | Single RL output container |
| AccelerateRLBatchElement | dataclass | Batched RL output container |
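A standard `@dataclass` does not validate its field annotations when an instance is constructed, so the shape labels serve as documentation unless a separate check is applied. The sketch below shows one way such a check might look for the shared "output_size" dimension; it uses a stand-in class with plain lists in place of tensors, and `check_rl_element` is a hypothetical helper, not part of trlx.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AccelerateRLElement:
    # Stand-in: lists replace the 1D tensors of the documented dataclass.
    output_tokens: List[int]
    rewards: List[float]


def check_rl_element(elem: AccelerateRLElement) -> None:
    """Hypothetical check: both fields share the 'output_size' dimension."""
    if len(elem.output_tokens) != len(elem.rewards):
        raise ValueError(
            f"output_tokens has {len(elem.output_tokens)} entries, "
            f"rewards has {len(elem.rewards)}"
        )


ok = AccelerateRLElement(output_tokens=[257, 640, 11, 612], rewards=[0.0, 0.0, 0.0, 1.0])
check_rl_element(ok)  # lengths match, so nothing is raised

# Construction succeeds even though the lengths disagree;
# only the explicit check reports the mismatch.
bad = AccelerateRLElement(output_tokens=[257, 640], rewards=[1.0])
try:
    check_rl_element(bad)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```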
Usage Examples
Create Prompt and RL Elements
import torch
from trlx.data.accelerate_base_datatypes import (
    PromptElement,
    PromptBatch,
    AccelerateRLElement,
)

# 1. Single prompt
prompt = PromptElement(
    text="Once upon a time",
    tokens=torch.tensor([7454, 2402, 257, 640]),
)

# 2. Batched prompts
batch = PromptBatch(
    text=["Hello world", "How are you"],
    tokens=torch.tensor([[15496, 995, 0], [2437, 389, 345]]),
)

# 3. RL element with rewards
rl_elem = AccelerateRLElement(
    output_tokens=torch.tensor([257, 640, 11, 612]),
    rewards=torch.tensor([0.0, 0.0, 0.0, 1.0]),
)
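The example above does not cover AccelerateRLBatchElement. The following sketch suggests how several single RL elements might be stacked into the batched form. It again relies on stand-in dataclasses with lists instead of tensors; the `collate_rl` helper, the right-padding, and the choice of 0.0 as the reward padding value are assumptions for illustration only.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AccelerateRLElement:
    # Stand-in for the 1D ["output_size"] fields.
    output_tokens: List[int]
    rewards: List[float]


@dataclass
class AccelerateRLBatchElement:
    # Stand-in for the 2D ["batch_size", "output_size"] fields.
    output_tokens: List[List[int]]
    rewards: List[List[float]]


def collate_rl(elems: List[AccelerateRLElement], pad_token_id: int = 0) -> AccelerateRLBatchElement:
    """Hypothetical helper: right-pad tokens and rewards to a common length."""
    max_len = max(len(e.output_tokens) for e in elems)
    tokens = [
        e.output_tokens + [pad_token_id] * (max_len - len(e.output_tokens))
        for e in elems
    ]
    # Padded positions get a reward of 0.0 (an assumption of this sketch).
    rewards = [e.rewards + [0.0] * (max_len - len(e.rewards)) for e in elems]
    return AccelerateRLBatchElement(output_tokens=tokens, rewards=rewards)


rl_batch = collate_rl([
    AccelerateRLElement(output_tokens=[257, 640, 11, 612], rewards=[0.0, 0.0, 0.0, 1.0]),
    AccelerateRLElement(output_tokens=[345, 389], rewards=[0.0, 0.5]),
])
print(rl_batch.output_tokens)
print(rl_batch.rewards)
```

Padding both fields to the same length keeps the two batch fields aligned on the "output_size" dimension, matching the shape annotations in the signature.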