

Implementation:CarperAI Trlx Accelerate Base Datatypes

From Leeroopedia


Knowledge Sources
Domains Data_Structures, Reinforcement_Learning
Last Updated 2026-02-07 16:00 GMT

Overview

Concrete tool providing dataclass definitions for prompt and RL data elements used by the Accelerate-based trainers.

Description

This module defines four simple dataclasses used across the Accelerate trainer stack: PromptElement (single prompt with text and tokens), PromptBatch (batched prompts), AccelerateRLElement (single RL output with tokens and rewards), and AccelerateRLBatchElement (batched RL outputs). All tensor fields are annotated with TensorType from torchtyping for shape documentation. These dataclasses serve as the standard data interchange format between pipelines and trainers.
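As a minimal sketch of what annotating "for shape documentation" means in practice: a standard-library dataclass records its field annotations but does not check them when an instance is constructed. The class below is a local stand-in with a hypothetical name (`PromptElementSketch`); it mirrors the field names of PromptElement but is not imported from trlx, and it uses `List[int]` in place of `TensorType["num_tokens"]` so that it runs without torch or torchtyping.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class PromptElementSketch:
    """Local stand-in mirroring PromptElement's fields (hypothetical name)."""

    text: str
    # Assumption: List[int] stands in for TensorType["num_tokens"].
    tokens: List[int]


# The annotations are recorded on the class and can be inspected ...
field_names = [f.name for f in fields(PromptElementSketch)]
print(field_names)

# ... but the dataclass machinery does not validate them: a value that
# does not match the annotation is accepted without error.
unchecked = PromptElementSketch(text="Once upon a time", tokens="not a tensor")
print(type(unchecked.tokens).__name__)
```

The same holds for any dataclass annotation, so callers should treat the declared shapes as a contract to honour rather than a guarantee that is checked for them.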

Usage

Use these dataclasses when constructing or consuming prompt and RL data in the Accelerate-based training pipeline. They are used internally by PromptPipeline, AcceleratePPOTrainer, and other trainer classes.
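The relationship between the single-element and batched containers can be sketched as a collate step that pads variable-length prompts to a common length. Everything here is an assumption made for illustration: the `*Sketch` classes are local stand-ins that use Python lists instead of tensors, `collate_prompts` is a hypothetical helper, and right-padding is an arbitrary choice that may differ from what PromptPipeline actually does.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class PromptElementSketch:
    text: str
    tokens: List[int]  # stands in for TensorType["num_tokens"]


@dataclass
class PromptBatchSketch:
    text: Iterable[str]
    tokens: List[List[int]]  # stands in for TensorType["batch_size", "num_tokens"]


def collate_prompts(
    elements: List[PromptElementSketch], pad_token_id: int
) -> PromptBatchSketch:
    """Right-pad every token sequence to the longest one (hypothetical helper)."""
    max_len = max(len(e.tokens) for e in elements)
    padded = [
        e.tokens + [pad_token_id] * (max_len - len(e.tokens)) for e in elements
    ]
    return PromptBatchSketch(text=[e.text for e in elements], tokens=padded)


elements = [
    PromptElementSketch(text="Hello world", tokens=[15496, 995]),
    PromptElementSketch(text="How are you", tokens=[2437, 389, 345]),
]
batch = collate_prompts(elements, pad_token_id=0)
print(batch.tokens)  # [[15496, 995, 0], [2437, 389, 345]]
```

With real tensors, the nested list would typically be replaced by a single stacked 2D tensor, which is what the `batch_size` and `num_tokens` dimensions in the batched annotation describe.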

Code Reference

Source Location

Signature

from dataclasses import dataclass
from typing import Iterable

from torchtyping import TensorType


@dataclass
class PromptElement:
    """Single prompt with text and tokenized representation."""
    text: str
    tokens: TensorType["num_tokens"]


@dataclass
class PromptBatch:
    """Batched prompts with text list and padded token tensor."""
    text: Iterable[str]
    tokens: TensorType["batch_size", "num_tokens"]


@dataclass
class AccelerateRLElement:
    """Single RL element with output tokens and per-token rewards."""
    output_tokens: TensorType["output_size"]
    rewards: TensorType["output_size"]


@dataclass
class AccelerateRLBatchElement:
    """Batched RL elements with output tokens and rewards tensors."""
    output_tokens: TensorType["batch_size", "output_size"]
    rewards: TensorType["batch_size", "output_size"]

Import

from trlx.data.accelerate_base_datatypes import (
    PromptElement,
    PromptBatch,
    AccelerateRLElement,
    AccelerateRLBatchElement,
)

I/O Contract

Inputs

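Name | Type | Required | Description
text | str (PromptElement) / Iterable[str] (PromptBatch) | Yes | Raw prompt text(s)
tokens | TensorType["num_tokens"] / TensorType["batch_size", "num_tokens"] | Yes | Tokenized prompt: 1D for a single prompt, 2D for a batch
output_tokens | TensorType["output_size"] / TensorType["batch_size", "output_size"] | Yes | Generated output token IDs
rewards | TensorType["output_size"] / TensorType["batch_size", "output_size"] | Yes | Per-token reward values, annotated with the same shape as output_tokens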
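The annotations for output_tokens and rewards share the `output_size` dimension, but the signatures shown on this page contain no validation logic. A caller who wants that contract enforced could add a check along the following lines. This is a sketch under stated assumptions: `RLElementSketch` is a hypothetical local stand-in that uses Python lists instead of tensors, and the `__post_init__` check is not part of the trlx classes.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RLElementSketch:
    """Local stand-in mirroring AccelerateRLElement's fields (hypothetical)."""

    output_tokens: List[int]  # stands in for TensorType["output_size"]
    rewards: List[float]      # stands in for TensorType["output_size"]

    def __post_init__(self) -> None:
        # Assumption: one reward per output token, as the shared
        # "output_size" dimension in the annotations suggests.
        if len(self.output_tokens) != len(self.rewards):
            raise ValueError(
                f"expected one reward per token, got "
                f"{len(self.output_tokens)} tokens and {len(self.rewards)} rewards"
            )


# Matching lengths construct normally.
ok = RLElementSketch(
    output_tokens=[257, 640, 11, 612],
    rewards=[0.0, 0.0, 0.0, 1.0],
)

# Mismatched lengths are rejected at construction time.
try:
    RLElementSketch(output_tokens=[257, 640], rewards=[1.0])
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
print(mismatch_detected)
```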

Outputs

Name | Type | Description
PromptElement | dataclass | Single prompt container
PromptBatch | dataclass | Batched prompt container
AccelerateRLElement | dataclass | Single RL output container
AccelerateRLBatchElement | dataclass | Batched RL output container

Usage Examples

Create Prompt and RL Elements

import torch
from trlx.data.accelerate_base_datatypes import (
    PromptElement,
    PromptBatch,
    AccelerateRLElement,
)

# 1. Single prompt
prompt = PromptElement(
    text="Once upon a time",
    tokens=torch.tensor([7454, 2402, 257, 640]),
)

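# 2. Batched prompts (the shorter prompt is right-padded with 0 here purely
#    for illustration; use your tokenizer's pad token ID in practice)
batch = PromptBatch(
    text=["Hello world", "How are you"],
    tokens=torch.tensor([[15496, 995, 0], [2437, 389, 345]]),
)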

# 3. RL element with rewards
rl_elem = AccelerateRLElement(
    output_tokens=torch.tensor([257, 640, 11, 612]),
    rewards=torch.tensor([0.0, 0.0, 0.0, 1.0]),
)
