

Implementation:Hpcaitech ColossalAI BaseModel



Knowledge Sources
Domains: Deep Learning, Model Architecture, RLHF
Last Updated: 2026-02-09 00:00 GMT

Overview

Base model class for ColossalChat critic and reward models that wraps a HuggingFace pretrained model.

Description

BaseModel is a PyTorch nn.Module subclass that serves as the foundation for both the Critic and RewardModel classes in ColossalChat. It initializes a HuggingFace AutoModel from a pretrained checkpoint path, a PretrainedConfig object, or both. During initialization, it runs a forward pass on a dummy input and stores the size of the final dimension of the resulting last hidden state. Subclasses use this value to size their value heads. It also supports Flash Attention 2: when that option is specified, the model is automatically moved to CUDA, since the Flash Attention 2 kernels require a GPU.
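The dummy-input probe described above can be sketched in isolation. This is a minimal illustration, not the ColossalChat source. It builds a tiny, randomly initialized GPT-2 model with `AutoModel.from_config`. The config sizes are arbitrary assumptions, chosen so that nothing is downloaded. It then runs a single `(1, 1)` batch of token id 0 through the model and reads the last dimension of `last_hidden_state`. Wrapping the pass in `torch.no_grad()` is a choice made here because no gradients are needed.

```python
import torch
from transformers import AutoModel, GPT2Config

# Tiny, randomly initialized config (sizes are arbitrary, for illustration only).
config = GPT2Config(vocab_size=128, n_positions=32, n_embd=48, n_layer=1, n_head=4)
model = AutoModel.from_config(config)  # AutoModel resolves this to a concrete GPT2Model
model.eval()

# One batch holding a single token id (0); only the output shape matters here.
dummy_input = torch.zeros((1, 1), dtype=torch.long, device=model.device)
with torch.no_grad():
    out = model(dummy_input)

# The final dimension of last_hidden_state is the width a value head must accept.
last_hidden_state_size = out.last_hidden_state.shape[-1]
print(type(model).__name__, last_hidden_state_size)  # GPT2Model 48
```

One likely motivation for probing the output rather than reading a config attribute is that the probe does not depend on every architecture exposing its width under the same attribute name.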

Usage

Use this class as a base when building value-head models (critics, reward models) on top of pretrained transformers. It is not intended to be instantiated directly; instead, use the Critic or RewardModel subclasses.
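The subclassing pattern can be illustrated without ColossalChat or a real checkpoint. In this sketch, `ToyBackbone`, `ToyBase` and `ToyCritic` are hypothetical stand-ins:

- `ToyBackbone` mimics a HuggingFace base model by returning an object with a `last_hidden_state` field.
- `ToyBase` probes the backbone once with a dummy input and records `last_hidden_state_size`.
- `ToyCritic` sizes a scalar value head from that attribute.

The per-token output of `ToyCritic.forward` is an assumption made for illustration and is not taken from the real Critic.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a HuggingFace base model."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor) -> SimpleNamespace:
        # Mimic the HF output object by exposing `last_hidden_state`.
        return SimpleNamespace(last_hidden_state=self.proj(self.embed(input_ids)))


class ToyBase(nn.Module):
    """Hypothetical simplified base: wraps a backbone and probes its width."""

    def __init__(self, backbone: nn.Module) -> None:
        super().__init__()
        self.model = backbone
        with torch.no_grad():
            out = self.model(torch.zeros((1, 1), dtype=torch.long))
        self.last_hidden_state_size = out.last_hidden_state.shape[-1]


class ToyCritic(ToyBase):
    """Hypothetical subclass that emits one scalar value per token."""

    def __init__(self, backbone: nn.Module) -> None:
        super().__init__(backbone)
        # The head's input width comes from the attribute set by the base class.
        self.value_head = nn.Linear(self.last_hidden_state_size, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.model(input_ids).last_hidden_state  # (batch, seq, hidden)
        return self.value_head(hidden).squeeze(-1)  # (batch, seq)


critic = ToyCritic(ToyBackbone(hidden=16))
values = critic(torch.randint(0, 100, (2, 5)))
print(critic.last_hidden_state_size, tuple(values.shape))  # 16 (2, 5)
```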

Code Reference

Source Location

Signature

```python
class BaseModel(nn.Module):
    def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
        ...

    def resize_token_embeddings(self, *args, **kwargs):
        ...
```

Import

```python
from coati.models.base import BaseModel
```

I/O Contract

Inputs (`__init__`)

| Name | Type | Required | Description |
|---|---|---|---|
| `pretrained` | `str` | No | Local path or HuggingFace Hub ID of a pretrained model |
| `config` | `PretrainedConfig` | No | Configuration used to initialize the model. At least one of `pretrained` or `config` must be provided. |
| `**kwargs` | `dict` | No | Additional keyword arguments forwarded to `AutoModel.from_pretrained` |
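The rule that at least one of `pretrained` or `config` must be supplied can be expressed as a small dispatch function. This is a hedged, dependency-free sketch. `choose_constructor` is a hypothetical helper that only returns a description of which HuggingFace entry point would be used.

The mapping of the `config`-only case to `AutoModel.from_config` is an assumption. This page documents only `AutoModel.from_pretrained`.

```python
from typing import Any, Optional


def choose_constructor(pretrained: Optional[str] = None, config: Optional[Any] = None) -> str:
    """Hypothetical helper: report which AutoModel entry point would be used."""
    if pretrained is not None and config is not None:
        # Weights come from the path; the explicit config replaces the stored one.
        return "AutoModel.from_pretrained(pretrained, config=config, **kwargs)"
    if pretrained is not None:
        return "AutoModel.from_pretrained(pretrained, **kwargs)"
    if config is not None:
        # Assumption: config-only means random initialization from the config.
        return "AutoModel.from_config(config, **kwargs)"
    raise ValueError("Either pretrained or config must be provided.")


print(choose_constructor(pretrained="path/to/model"))
# prints: AutoModel.from_pretrained(pretrained, **kwargs)
```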

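Outputs (attributes set by `__init__`)

| Name | Type | Description |
|---|---|---|
| `self.model` | `PreTrainedModel` | The wrapped HuggingFace transformer; `AutoModel` resolves it to a concrete architecture class (e.g. `LlamaModel`) |
| `self.last_hidden_state_size` | `int` | Size of the final dimension of the model's last hidden state, used by subclasses to size their value heads |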

Usage Examples

```python
from transformers import AutoTokenizer

# BaseModel is typically used through its subclasses:
from coati.models.critic import Critic
from coati.models.reward_model import RewardModel

# Initialize a Critic from a pretrained model
critic = Critic(pretrained="meta-llama/Llama-2-7b-hf")

# The base class provides last_hidden_state_size
print(critic.last_hidden_state_size)  # 4096 for Llama-2-7b

# Resize token embeddings after adding special tokens to the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_special_tokens({"pad_token": "<pad>"})
critic.resize_token_embeddings(new_num_tokens=len(tokenizer))  # 32001
```
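`resize_token_embeddings` accepts `*args, **kwargs`, which suggests that it forwards to the wrapped model's method of the same name. The sketch below assumes that delegation and checks its effect on a tiny, randomly initialized GPT-2 backbone. `ToyWrapper` is a hypothetical stand-in for `BaseModel`, not the ColossalChat class, and the config sizes are arbitrary.

```python
import torch.nn as nn
from transformers import AutoModel, GPT2Config


class ToyWrapper(nn.Module):
    """Hypothetical stand-in for BaseModel that owns a HuggingFace model."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def resize_token_embeddings(self, *args, **kwargs):
        # Assumed behavior: delegate to PreTrainedModel.resize_token_embeddings.
        return self.model.resize_token_embeddings(*args, **kwargs)


config = GPT2Config(vocab_size=128, n_positions=32, n_embd=48, n_layer=1, n_head=4)
wrapper = ToyWrapper(AutoModel.from_config(config))

# For example, after adding one special token to a 128-token vocabulary.
new_embeddings = wrapper.resize_token_embeddings(new_num_tokens=129)
print(new_embeddings.weight.shape)  # torch.Size([129, 48])
```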
