
Implementation:Hpcaitech ColossalAI Critic

From Leeroopedia


Knowledge Sources
Domains: Reinforcement Learning, RLHF, Model Architecture
Last Updated: 2026-02-09 00:00 GMT

Overview

Critic model for PPO-based RLHF training that produces per-token value estimates.

Description

The Critic class extends BaseModel with a linear value head that maps the last hidden state of a pretrained transformer to a scalar value estimate at each token position. The forward method passes the input token IDs through the base transformer, takes the last hidden states, and applies the value head to produce value predictions of shape (B, S), where B is the batch size and S is the sequence length. The class also provides accessor methods for the input and output embeddings.
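The value-head idea described above can be illustrated with a small, self-contained sketch. This is not the coati implementation: ToyBackbone and ToyCritic are hypothetical names, and the backbone is just an embedding layer standing in for a pretrained transformer. The sketch only shows how a linear layer with one output unit turns hidden states of shape (B, S, H) into values of shape (B, S).

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a pretrained transformer (embeddings only)."""

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Plays the role of the "last hidden states", shape (B, S, H).
        return self.embed(input_ids)


class ToyCritic(nn.Module):
    """Backbone plus a linear value head (illustrative, not the coati class)."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16) -> None:
        super().__init__()
        self.backbone = ToyBackbone(vocab_size, hidden_size)
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)           # (B, S, H)
        return self.value_head(hidden).squeeze(-1)  # (B, S)


torch.manual_seed(0)
toy_critic = ToyCritic()
toy_ids = torch.randint(0, 100, (2, 8))  # B=2, S=8
toy_values = toy_critic(toy_ids)         # one scalar per token position
```

Squeezing the trailing dimension is what gives one scalar per token instead of a (B, S, 1) tensor.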

Usage

Use this model as the critic/value function in PPO training within the ColossalChat RLHF pipeline. It estimates the expected future reward at each token position, which is used to compute advantages for the policy gradient update.
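As an illustration of how per-token values can feed into advantages, here is a minimal sketch of Generalized Advantage Estimation (GAE), one common estimator in PPO. The page does not state which estimator ColossalChat uses, so treat this as an assumption; the function name gae_advantages and its default gamma and lam values are hypothetical.

```python
from typing import List


def gae_advantages(
    rewards: List[float],
    values: List[float],
    gamma: float = 1.0,
    lam: float = 0.95,
) -> List[float]:
    """GAE over one sequence; values[t] is the critic's estimate at token t."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # The value after the final token is taken to be zero (end of episode).
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # TD residual: how much better this step was than the critic expected.
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


# A reward of 1.0 arrives only at the last token; the critic predicts 0.5 everywhere.
example_rewards = [0.0, 0.0, 1.0]
example_values = [0.5, 0.5, 0.5]
example_adv = gae_advantages(example_rewards, example_values, gamma=1.0, lam=1.0)
```

With gamma and lam both set to 1.0, each advantage reduces to the sum of future rewards minus the value estimate at that position.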

Code Reference

Source Location

Signature

class Critic(BaseModel):
    def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:

    def get_input_embeddings(self):

    def get_output_embeddings(self):

Import

from coati.models.critic import Critic

I/O Contract

Inputs (forward)

Name | Type | Required | Description
input_ids | torch.LongTensor | Yes | Input token IDs of shape (B, S)
attention_mask | torch.Tensor | No | Attention mask of shape (B, S)

Outputs (forward)

Name | Type | Description
values | torch.Tensor | Per-position value estimates of shape (B, S)

Usage Examples

import torch
from coati.models.critic import Critic

# Initialize the critic from a pretrained model and move it to the GPU
critic = Critic(pretrained="meta-llama/Llama-2-7b-hf").cuda()

# Forward pass on a dummy batch: B=2 sequences of S=128 tokens
input_ids = torch.randint(0, 32000, (2, 128)).cuda()
attention_mask = torch.ones(2, 128, dtype=torch.long).cuda()
with torch.no_grad():  # inference only, so skip gradient tracking
    values = critic(input_ids, attention_mask=attention_mask)
print(values.shape)  # torch.Size([2, 128])
