Implementation:Hpcaitech ColossalAI BaseModel
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Model Architecture, RLHF |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Base model class for ColossalChat critic and reward models that wraps a HuggingFace pretrained model.
Description
BaseModel is a PyTorch nn.Module subclass that serves as the foundation for both the Critic and RewardModel classes in ColossalChat. It initializes a HuggingFace AutoModel from a pretrained checkpoint path, a PretrainedConfig object, or both. During initialization, it runs a forward pass on a dummy input to determine the dimensionality of the last hidden state and stores that value for use by subclass value heads. To support Flash Attention 2, it automatically moves the model to CUDA when that option is specified.
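The dummy-input probe described above can be sketched as follows. This is a minimal illustration, not the code in base.py: TinyBackbone and probe_hidden_size are hypothetical names, the stand-in backbone replaces a real HuggingFace model so the snippet runs without downloading weights, and the use of torch.no_grad() is an assumption.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a HuggingFace transformer (not part of coati)."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor) -> SimpleNamespace:
        # Mimic the `last_hidden_state` attribute of a HuggingFace model output.
        return SimpleNamespace(last_hidden_state=self.embed(input_ids))


def probe_hidden_size(model: nn.Module) -> int:
    """Run a single dummy token through `model` and read the trailing dimension."""
    device = next(model.parameters()).device
    dummy_input = torch.zeros((1, 1), dtype=torch.long, device=device)
    with torch.no_grad():  # assumption: gradients are not needed for a shape probe
        out = model(dummy_input)
    return out.last_hidden_state.shape[-1]


hidden_size = probe_hidden_size(TinyBackbone(hidden_size=16))
print(hidden_size)
```

Probing the output shape this way avoids relying on a particular config field name for the hidden size, which can differ between model families.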
Usage
Use this class as a base when building value-head models (critics, reward models) on top of pretrained transformers. It is not intended to be instantiated directly; instead, use the Critic or RewardModel subclasses.
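As a rough sketch of what building a value-head model on top of the base can look like, the snippet below sizes a scalar head from last_hidden_state_size. StubBase and ToyRewardModel are hypothetical names, StubBase only imitates the two attributes that BaseModel exposes, and pooling on the last token is an assumption rather than the behaviour of the actual Critic or RewardModel classes.

```python
import torch
import torch.nn as nn


class StubBase(nn.Module):
    """Hypothetical stand-in for BaseModel: exposes `model` and `last_hidden_state_size`."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16) -> None:
        super().__init__()
        self.model = nn.Embedding(vocab_size, hidden_size)  # pretend backbone
        self.last_hidden_state_size = hidden_size


class ToyRewardModel(StubBase):
    """Illustrative value-head model; not the actual coati RewardModel."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Size the scalar head from the attribute the base class provides.
        self.value_head = nn.Linear(self.last_hidden_state_size, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.model(input_ids)  # (batch, seq, hidden)
        last_token = hidden[:, -1, :]  # (batch, hidden); assumed pooling choice
        return self.value_head(last_token).squeeze(-1)  # (batch,)


scores = ToyRewardModel()(torch.zeros((2, 5), dtype=torch.long))
print(scores.shape)
```

Because the head reads its input width from the base class, the same subclass code works unchanged across backbones with different hidden sizes.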
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/models/base.py
- Lines: 1-57
Signature
class BaseModel(nn.Module):
    def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
    def resize_token_embeddings(self, *args, **kwargs):
Import
from coati.models.base import BaseModel
I/O Contract
Inputs (__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| pretrained | str | No | Path or Hub model ID of a pretrained HuggingFace model |
| config | PretrainedConfig | No | PretrainedConfig for model initialization (at least one of pretrained or config must be provided) |
| **kwargs | dict | No | Additional keyword arguments passed to AutoModel.from_pretrained |
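The interplay between the two optional arguments can be illustrated with a small pure-Python sketch. The function name choose_init_path and the returned labels are hypothetical, and raising ValueError when both arguments are missing is an assumption based on the requirement that at least one of them be provided.

```python
from typing import Any, Optional


def choose_init_path(pretrained: Optional[str] = None, config: Optional[Any] = None) -> str:
    """Return a label for the loading path implied by the two optional arguments."""
    if pretrained is not None and config is not None:
        return "pretrained weights with explicit config"
    if pretrained is not None:
        return "pretrained weights with their own config"
    if config is not None:
        return "fresh model built from config"
    # Assumed behaviour: neither argument given is treated as an error.
    raise ValueError("Either pretrained or config must be provided.")


print(choose_init_path(pretrained="path/to/checkpoint"))
print(choose_init_path(config=object()))
```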
Outputs
| Name | Type | Description |
|---|---|---|
| self.model | PreTrainedModel | The wrapped HuggingFace transformer model, loaded through AutoModel |
| self.last_hidden_state_size | int | Dimensionality of the last hidden state, used by subclass value heads |
Usage Examples
from coati.models.base import BaseModel
# Typically used via subclasses:
from coati.models.critic import Critic
from coati.models.reward_model import RewardModel
# Initialize a Critic from a pretrained model
critic = Critic(pretrained="meta-llama/Llama-2-7b-hf")
# The base model provides last_hidden_state_size
print(critic.last_hidden_state_size) # e.g., 4096
# Resize token embeddings (e.g., after adding special tokens)
critic.resize_token_embeddings(new_num_tokens=32001)