Implementation:Hpcaitech ColossalAI Critic
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, RLHF, Model Architecture |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Critic model for PPO-based RLHF training that produces per-token value estimates.
Description
The Critic class extends BaseModel by adding a linear value head that maps the last hidden state of a pretrained transformer to a scalar value estimate at each token position. The forward method runs the input token IDs through the base transformer, takes the last hidden states, and applies the value head to produce per-position value predictions of shape (B, S), where B is the batch size and S is the sequence length. It also provides accessor methods (get_input_embeddings, get_output_embeddings) for the input and output embeddings.
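The structure described above can be sketched in a few lines of PyTorch. This is a minimal, self-contained illustration, not the ColossalChat source: `TinyBackbone` and `CriticSketch` are hypothetical names, and a single `nn.TransformerEncoderLayer` stands in for the pretrained transformer that BaseModel would load. The key step is the value head, an `nn.Linear(hidden_size, 1)` applied to every position, followed by `squeeze(-1)` to turn (B, S, 1) into (B, S).

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the pretrained transformer wrapped by BaseModel."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=2, dim_feedforward=32, batch_first=True
        )

    def forward(self, input_ids, attention_mask=None):
        hidden = self.embed_tokens(input_ids)  # (B, S, H)
        # PyTorch's key padding mask uses True for positions to IGNORE,
        # i.e. the inverse of a Hugging Face-style attention mask
        padding_mask = None if attention_mask is None else attention_mask == 0
        hidden = self.layer(hidden, src_key_padding_mask=padding_mask)
        return {"last_hidden_state": hidden}

    def get_input_embeddings(self):
        return self.embed_tokens


class CriticSketch(nn.Module):
    """Hypothetical sketch: backbone plus a linear value head (H -> 1) per token."""

    def __init__(self, backbone: nn.Module, hidden_size: int) -> None:
        super().__init__()
        self.model = backbone
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        last_hidden = self.model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
        # (B, S, H) -> (B, S, 1) -> (B, S): one scalar value per token position
        return self.value_head(last_hidden).squeeze(-1)

    def get_input_embeddings(self):
        # Delegate to the wrapped backbone
        return self.model.get_input_embeddings()


torch.manual_seed(0)
critic = CriticSketch(TinyBackbone(), hidden_size=16)
input_ids = torch.randint(0, 100, (2, 5))
attention_mask = torch.ones(2, 5, dtype=torch.long)
values = critic(input_ids, attention_mask=attention_mask)
print(values.shape)  # torch.Size([2, 5])
```

Because the head is a plain linear layer on top of the backbone's hidden states, the critic adds only H + 1 parameters to the base model.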
Usage
Use this model as the critic (value function) in PPO training within the ColossalChat RLHF pipeline. It estimates the expected future reward at each token position; these estimates are used to compute the advantages for the policy-gradient update.
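To show how per-token values feed into advantages, here is a sketch of Generalized Advantage Estimation (GAE), a common choice in PPO. The source does not state which estimator or which `gamma`/`lam` settings ColossalChat uses, so treat `compute_gae` as a hypothetical helper illustrating the standard recursion: delta_t = r_t + gamma * V_{t+1} - V_t, and A_t = delta_t + gamma * lam * A_{t+1}, with the value after the final position taken as 0.

```python
import torch


def compute_gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float = 1.0, lam: float = 0.95):
    """Hypothetical GAE helper over the last dimension of (B, T) tensors.

    Assumes the episode ends after the last position, so V(s_T) = 0.
    Returns (advantages, returns); returns are the critic's regression targets.
    """
    T = rewards.size(-1)
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(rewards[..., 0])
    for t in reversed(range(T)):
        next_value = values[..., t + 1] if t + 1 < T else torch.zeros_like(values[..., 0])
        # One-step TD error at position t
        delta = rewards[..., t] + gamma * next_value - values[..., t]
        # Exponentially weighted sum of future TD errors
        last_gae = delta + gamma * lam * last_gae
        advantages[..., t] = last_gae
    returns = advantages + values
    return advantages, returns


# Sparse reward at the final token, as is typical when a reward model scores a full response
rewards = torch.tensor([[0.0, 0.0, 1.0]])
values = torch.tensor([[0.5, 0.5, 0.5]])  # e.g. critic outputs for these positions
advantages, returns = compute_gae(rewards, values, gamma=1.0, lam=1.0)
print(advantages)  # tensor([[0.5000, 0.5000, 0.5000]])
print(returns)     # tensor([[1., 1., 1.]])
```

With `gamma=1.0` and `lam=1.0` the returns reduce to the undiscounted sum of future rewards; smaller `lam` trades variance for bias by leaning more on the critic's estimates.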
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/models/critic.py
- Lines: 1-40
Signature
```python
class Critic(BaseModel):
    def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None: ...
    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: ...
    def get_input_embeddings(self): ...
    def get_output_embeddings(self): ...
```
Import
```python
from coati.models.critic import Critic
```
I/O Contract
Inputs (forward)
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Input token IDs of shape (B, S) |
| attention_mask | torch.Tensor | No | Attention mask of shape (B, S); 1 marks tokens to attend to, 0 marks padding |
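In practice, batches contain sequences of different lengths, so input_ids is padded to a common length S and the attention mask records which positions are real. The sketch below assumes right padding and a pad token ID of 0; `pad_batch` is a hypothetical helper, and in a real pipeline the tokenizer normally produces both tensors (and may pad on the left instead).

```python
import torch


def pad_batch(sequences, pad_token_id=0):
    """Hypothetical helper: right-pad token-ID lists into (B, S) input_ids and attention_mask."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = torch.full((len(sequences), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        input_ids[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        attention_mask[i, : len(seq)] = 1  # 1 = real token, 0 = padding
    return input_ids, attention_mask


input_ids, attention_mask = pad_batch([[5, 6, 7], [8, 9]], pad_token_id=0)
print(input_ids)       # tensor([[5, 6, 7], [8, 9, 0]])
print(attention_mask)  # tensor([[1, 1, 1], [1, 1, 0]])
```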
Outputs (forward)
| Name | Type | Description |
|---|---|---|
| values | torch.Tensor | Per-position value estimates of shape (B, S) |
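Since the output has one entry for every position, including padded ones, downstream code usually has to mask out padding before aggregating values. A minimal sketch, assuming the attention mask from the input is reused for this purpose; `masked_mean` is a hypothetical helper, not part of the coati API.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: average (B, S) values over positions where mask == 1."""
    mask = mask.to(values.dtype)
    # clamp avoids division by zero for an all-padding row
    return (values * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1.0)


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 6.0, 100.0]])  # 100.0 sits on a padded position
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
print(masked_mean(values, attention_mask))  # tensor([2., 5.])
```

Without the mask, the padded position would pull the second row's mean from 5.0 up to about 36.7.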
Usage Examples
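```python
import torch
from coati.models.critic import Critic

# Initialize the critic from a pretrained Hugging Face checkpoint and move it to the GPU
critic = Critic(pretrained="meta-llama/Llama-2-7b-hf").cuda()

# Dummy batch: 2 sequences of 128 token IDs from Llama-2's 32,000-token vocabulary
input_ids = torch.randint(0, 32000, (2, 128), device="cuda")
# All-ones integer mask: every position is a real token (0 would mark padding)
attention_mask = torch.ones(2, 128, dtype=torch.long, device="cuda")

# Forward pass without gradient tracking: one value estimate per token position
with torch.no_grad():
    values = critic(input_ids, attention_mask=attention_mask)
print(values.shape)  # torch.Size([2, 128])
```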