Overview
Utility functions and data structures for splitting, batching, and padding experience data in the RLHF pipeline.
Description
This module provides the BufferItem dataclass and utility functions for converting between batched Experience objects and individual BufferItem entries. split_experience_batch unbinds each tensor field of a batched Experience along the batch dimension, producing a list of BufferItem instances, one per sample. make_experience_batch performs the inverse operation: it stacks individual items back into a batched Experience and zero-pads the variable-length action sequences to a common length. The private helper _zero_pad_sequences pads a list of tensors with zeros, on either the left or the right, to a uniform length.
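The padding step can be illustrated with a minimal sketch that assumes 1-D input tensors. The name `zero_pad_sequences` is hypothetical: it mirrors the behavior described for the private `_zero_pad_sequences` helper and is not its actual source. It relies on `torch.nn.functional.pad`, which takes a `(before, after)` pair for the last dimension.

```python
from typing import List

import torch
import torch.nn.functional as F


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
    """Hypothetical sketch: zero-pad 1-D tensors to a common length and stack them."""
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    max_len = max(seq.size(0) for seq in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        # F.pad takes (pad_before, pad_after) for the last dimension
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded.append(F.pad(seq, padding, value=0))
    # Result has shape (len(sequences), max_len)
    return torch.stack(padded, dim=0)


# Two action sequences of different lengths
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0])
print(zero_pad_sequences([a, b], side="left"))   # rows: [1, 2, 3] and [0, 0, 4]
print(zero_pad_sequences([a, b], side="right"))  # rows: [1, 2, 3] and [4, 0, 0]
```

With left padding every row ends at the same column; with right padding every row starts at column 0.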
Usage
Use these utilities within the experience buffer infrastructure to convert between batched experience data (produced by the experience maker) and individual items (stored in the buffer), and to re-batch items during sampling.
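The buffer workflow (split on insert, re-batch on sample) can be pictured with a toy model. In this sketch a batch is a plain dict of tensors rather than an Experience, and `ToyBuffer` is a hypothetical class, not part of coati. It also assumes every field has the same per-sample shape, so re-batching can use `torch.stack` without any padding.

```python
import random
from typing import Dict, List

import torch

# Hypothetical stand-in: a batch maps field names to tensors of shape (B, ...)
Batch = Dict[str, torch.Tensor]


class ToyBuffer:
    """Hypothetical buffer that stores per-sample items and re-batches on sampling."""

    def __init__(self) -> None:
        self.items: List[Dict[str, torch.Tensor]] = []

    def append(self, batch: Batch) -> None:
        # Split: unbind every field along dim 0, producing one dict per batch element
        batch_size = next(iter(batch.values())).size(0)
        per_field = {key: torch.unbind(value, dim=0) for key, value in batch.items()}
        for i in range(batch_size):
            self.items.append({key: vals[i] for key, vals in per_field.items()})

    def sample(self, k: int) -> Batch:
        # Re-batch: draw k stored items at random and stack each field along dim 0
        chosen = random.sample(self.items, k)
        return {key: torch.stack([item[key] for item in chosen], dim=0) for key in chosen[0]}

    def __len__(self) -> int:
        return len(self.items)


buffer = ToyBuffer()
buffer.append({"sequences": torch.randint(0, 100, (4, 6)), "reward": torch.randn(4)})
batch = buffer.sample(2)
print(len(buffer), batch["sequences"].shape, batch["reward"].shape)
```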
Code Reference
Source Location
Signature
from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class BufferItem:
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    kl: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

def split_experience_batch(experience: Experience) -> List[BufferItem]: ...

def make_experience_batch(items: List[BufferItem]) -> Experience: ...
Import
from coati.experience_buffer.utils import BufferItem, split_experience_batch, make_experience_batch
I/O Contract
Inputs (split_experience_batch)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| experience | Experience | Yes | A batched Experience object whose tensor fields share a leading batch dimension B, i.e. shape (B, ...) |
Outputs (split_experience_batch)
| Name | Type | Description |
|------|------|-------------|
| return | List[BufferItem] | List of individual BufferItem instances, one per batch element |
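As a rough illustration of this contract, the sketch below unbinds each tensor field along dimension 0. `ToyExperience` is a hypothetical stand-in assumed to expose the same field names as BufferItem, and `split_experience_sketch` is an illustrative name. Copying a `None` optional field to every item is an assumption, not confirmed behavior of `split_experience_batch`.

```python
from dataclasses import dataclass, fields
from typing import List, Optional

import torch


@dataclass
class BufferItem:
    # Field list as documented for BufferItem; each tensor holds a single sample
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    kl: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


@dataclass
class ToyExperience:
    # Hypothetical stand-in for Experience: same field names, leading batch dim B
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    kl: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


def split_experience_sketch(experience: ToyExperience) -> List[BufferItem]:
    """Hypothetical sketch: one BufferItem per batch element."""
    batch_size = experience.sequences.size(0)
    per_item = [{} for _ in range(batch_size)]
    for f in fields(BufferItem):
        value = getattr(experience, f.name)
        if isinstance(value, torch.Tensor):
            # Unbind along dim 0: a (B, ...) tensor becomes B tensors of shape (...)
            vals = torch.unbind(value, dim=0)
        else:
            # Assumption: a None optional field is copied to every item
            vals = [value] * batch_size
        for i, v in enumerate(vals):
            per_item[i][f.name] = v
    return [BufferItem(**kwargs) for kwargs in per_item]


B, S, A = 3, 10, 4
exp = ToyExperience(
    sequences=torch.randint(0, 100, (B, S)),
    action_log_probs=torch.randn(B, A),
    values=torch.randn(B),
    reward=torch.randn(B),
    kl=torch.randn(B),
    advantages=torch.randn(B),
    attention_mask=torch.ones(B, S, dtype=torch.long),
    action_mask=None,
)
items = split_experience_sketch(exp)
print(len(items), items[0].sequences.shape, items[0].action_mask)
```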
Inputs (make_experience_batch)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| items | List[BufferItem] | Yes | List of individual BufferItem instances to batch together |
Outputs (make_experience_batch)
| Name | Type | Description |
|------|------|-------------|
| return | Experience | A batched Experience object with zero-padded variable-length fields |
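The inverse direction can be sketched under several assumptions: `ToyExperience` is a hypothetical stand-in with the same field names as BufferItem; the padded fields are assumed to be `action_log_probs` and `action_mask` (the "variable-length action sequences"); left padding is assumed; and an optional field that is `None` in every item stays `None`. The names `make_experience_sketch`, `zero_pad`, and `PAD_KEYS` are illustrative only.

```python
from dataclasses import dataclass, fields
from typing import List, Optional

import torch
import torch.nn.functional as F


@dataclass
class BufferItem:
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    kl: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


@dataclass
class ToyExperience:
    # Hypothetical stand-in for Experience: same field names, batched along dim 0
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    kl: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


# Assumption: only the action-length fields differ in length between items
PAD_KEYS = {"action_log_probs", "action_mask"}


def zero_pad(seqs: List[torch.Tensor], side: str = "left") -> torch.Tensor:
    # Pad 1-D tensors with zeros to the longest length, then stack into (N, max_len)
    max_len = max(s.size(0) for s in seqs)
    padded = []
    for s in seqs:
        pad_len = max_len - s.size(0)
        padded.append(F.pad(s, (pad_len, 0) if side == "left" else (0, pad_len)))
    return torch.stack(padded, dim=0)


def make_experience_sketch(items: List[BufferItem]) -> ToyExperience:
    """Hypothetical sketch: stack per-sample items back into one batch."""
    kwargs = {}
    for f in fields(BufferItem):
        vals = [getattr(item, f.name) for item in items]
        if all(v is None for v in vals):
            # Assumption: an optional field absent from every item stays None
            kwargs[f.name] = None
        elif f.name in PAD_KEYS:
            # Variable-length action fields: zero-pad (left side assumed), then stack
            kwargs[f.name] = zero_pad(vals, side="left")
        else:
            # Fixed-shape fields: stack along a new batch dimension 0
            kwargs[f.name] = torch.stack(vals, dim=0)
    return ToyExperience(**kwargs)


def make_item(num_actions: int) -> BufferItem:
    # Hypothetical per-sample item: 6-token sequence with `num_actions` log-probs
    return BufferItem(
        sequences=torch.arange(6),
        action_log_probs=torch.ones(num_actions),
        values=torch.tensor(0.5),
        reward=torch.tensor(1.0),
        kl=torch.tensor(0.0),
        advantages=torch.tensor(0.2),
        attention_mask=torch.ones(6, dtype=torch.long),
        action_mask=None,
    )


batch = make_experience_sketch([make_item(2), make_item(4)])
print(batch.sequences.shape, batch.action_log_probs)
```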
Usage Examples
from coati.experience_buffer.utils import split_experience_batch, make_experience_batch
# experience_batch is a batched Experience, e.g. one produced by the experience maker
# Split the batch into individual BufferItem instances, one per sample
items = split_experience_batch(experience_batch)
# Re-batch a subset of items (e.g., for sampling); variable-length fields are zero-padded
sampled_items = items[:8]
new_batch = make_experience_batch(sampled_items)