Overview
chatglm.py defines model wrapper classes for evaluating ChatGLM and ChatGLM2 models within the ColossalEval evaluation framework.
Description
This module provides two classes, ChatGLMModel and ChatGLM2Model. Both extend HuggingFaceModel to handle the tokenization and generation behavior specific to THUDM's ChatGLM family of models. The classes do three things:

- They truncate prompts with custom logic.
- They compute loss only on target tokens, masking prompt tokens with IGNORE_INDEX.
- They extract choice scores during generation for single-choice questions.

ChatGLM2Model also overrides generate. Its version formats inputs with the tokenizer's build_prompt method and returns the logit scores at the choice token indices alongside the generated text.
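The truncation rule can be shown on plain token-id lists. This is a minimal sketch that keeps the head and tail of an over-long prompt and drops the middle. It assumes the prompt budget is model_max_length - max_new_tokens, matching the max_new_tokens parameter of _get_truncated_prompts. truncate_middle is a hypothetical helper. The real methods take and return strings, encoding and decoding with the model's tokenizer around the slicing.

```python
from typing import List


def truncate_middle(token_ids: List[int], model_max_length: int,
                    max_new_tokens: int) -> List[int]:
    """Hypothetical helper: keep the head and tail of an over-long prompt.

    Assumes the prompt budget is model_max_length - max_new_tokens, so that
    room is left for the tokens to be generated (or scored).
    """
    budget = model_max_length - max_new_tokens
    if len(token_ids) <= budget:
        return list(token_ids)  # fits already, nothing to drop
    half = max(budget, 0) // 2
    # Slice the tail with an explicit start index: token_ids[-0:] would
    # return the whole list when half == 0.
    return list(token_ids[:half]) + list(token_ids[len(token_ids) - half:])


# A 10-token prompt with a budget of 8 - 2 = 6 keeps 3 head and 3 tail tokens.
print(truncate_middle(list(range(10)), model_max_length=8, max_new_tokens=2))
# [0, 1, 2, 7, 8, 9]
```

With an odd budget, one slot stays unused because both halves use budget // 2.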
Usage
Use these classes within the ColossalEval framework to evaluate ChatGLM-6B or ChatGLM2-6B models on benchmarks. The evaluation pipeline instantiates them when the evaluation configuration selects ChatGLMModel or ChatGLM2Model as the model class.
Code Reference
Source Location
Signature
class ChatGLMModel(HuggingFaceModel):
    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]
    def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]],
                 calculate_overall_loss: bool = False) -> List[List[float]]
    def _calculate_loss(self, input_ids_list: List[torch.LongTensor],
                        labels: List[torch.LongTensor]) -> List[float]
class ChatGLM2Model(ChatGLMModel):
    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]
    def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]
    def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]],
                 calculate_overall_loss: bool = False) -> List[List[float]]
Import
from colossal_eval.models.chatglm import ChatGLMModel, ChatGLM2Model
I/O Contract
Inputs (get_loss)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| batch_prompt | List[str] | Yes | Batch of prompt strings without target answers |
| batch_target | List[List[str]] | Yes | Batch of target answers; each prompt may have multiple targets |
| calculate_overall_loss | bool | No | Whether to calculate overall loss (default: False) |
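The sketch below makes the IGNORE_INDEX masking concrete. It builds one training pair from already-tokenized prompt and target ids, treating each (prompt, target) pair independently.

It assumes a ChatGLM2-style layout of prompt + target + eos, with the eos position also masked. The ChatGLM v1 path instead assembles inputs with build_inputs_with_special_tokens, which adds ChatGLM's own special tokens. build_example and all token ids here are hypothetical.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value skipped by the loss (PyTorch's default ignore_index)


def build_example(prompt_ids: List[int], target_ids: List[int],
                  eos_id: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: concatenate prompt and target, masking the prompt.

    Labels stay aligned with input_ids; the one-position shift between
    logits and labels happens later, inside the loss computation.
    """
    input_ids = prompt_ids + target_ids + [eos_id]
    labels = ([IGNORE_INDEX] * len(prompt_ids)  # no loss on the prompt
              + list(target_ids)                 # loss on the answer tokens
              + [IGNORE_INDEX])                  # no loss on eos
    return input_ids, labels


# One prompt with two candidate answers becomes two independent examples.
prompt = [11, 12, 13]
for target in ([21], [22, 23]):
    print(build_example(prompt, target, eos_id=2))
```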
Inputs (generate - ChatGLM2Model)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| inputs | List[str] | Yes | List of input prompt strings |
| max_new_tokens | int | Yes | Maximum number of new tokens to generate |
| **kwargs | dict | No | Additional keyword arguments forwarded to the underlying model.generate() |
Outputs (get_loss)
get_loss returns a three-element tuple, even though its signature is annotated as returning List[List[float]].
| Name | Type | Description |
|------|------|-------------|
| losses_per_sample | List[List[float]] | Per-sample losses grouped by prompt, with one entry per target |
| target_token_nums_per_sample | List[List[int]] | Number of target tokens per sample, grouped the same way |
| None | None | Third return value, always None (placeholder) |
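Every (prompt, target) pair becomes its own sequence, so the number of sequences can exceed batch_size. The sequences are re-batched into chunks of batch_size. The flat loss list is then split back per prompt to produce the nested outputs in the table above.

The helpers below are a hypothetical, framework-free sketch of that bookkeeping. The loss values in the example are made up for illustration.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk(items: Sequence[T], batch_size: int) -> List[List[T]]:
    """Hypothetical helper: split a flat list into batches of at most batch_size."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def regroup(flat: Sequence[T], targets_per_prompt: Sequence[int]) -> List[List[T]]:
    """Hypothetical helper: split flat per-pair results into one sub-list per prompt."""
    grouped, start = [], 0
    for n in targets_per_prompt:
        grouped.append(list(flat[start:start + n]))
        start += n
    return grouped


batch_target = [["Paris", "The capital of France is Paris."], ["4"]]
pairs = [(p, t) for p, targets in enumerate(batch_target) for t in targets]
print(chunk(pairs, batch_size=2))  # 3 pairs -> a batch of 2 and a batch of 1

flat_losses = [1.5, 4.0, 0.5]  # made-up values, one per (prompt, target) pair
print(regroup(flat_losses, [len(t) for t in batch_target]))  # [[1.5, 4.0], [0.5]]
```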
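Outputs (generate - ChatGLM2Model)
generate returns a (decoded_sequences, scores) tuple, even though its signature is annotated as returning List[str].
| Name | Type | Description |
|------|------|-------------|
| decoded_sequences | List[str] | Generated text only; the prompt tokens are not included |
| scores | Tensor | Logits of the first generated token at the choice token indices, used for single-choice evaluation; an empty list when no choice indices are set |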
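For single-choice questions, the scores come from the logits of the first generated token, restricted to the token ids of the answer options. The sketch below shows that selection on plain lists.

It rests on two assumptions, which should be checked against the real method:
- Each option may map to several token-id variants.
- The variants are combined with an elementwise maximum.

choice_scores and the toy token ids are hypothetical.

```python
from typing import List


def choice_scores(first_step_logits: List[List[float]],
                  indices_for_choices: List[List[int]]) -> List[List[float]]:
    """Hypothetical helper.

    first_step_logits:   [batch][vocab] logits of the first generated token.
    indices_for_choices: [variant][choice] token ids; every variant lists one
                         token id per option, in the same option order.
    Returns [batch][choice] scores, taking the max over variants.
    """
    num_choices = len(indices_for_choices[0])
    return [
        [max(row[variant[c]] for variant in indices_for_choices) for c in range(num_choices)]
        for row in first_step_logits
    ]


logits = [[0.1, 2.0, 0.5, 1.0, 3.0]]  # toy vocabulary of 5 tokens, batch of 1
indices = [[1, 2], [3, 4]]            # two variants for options "A" and "B"
scores = choice_scores(logits, indices)
print(scores)  # [[2.0, 3.0]]
predicted = max(range(len(scores[0])), key=lambda c: scores[0][c])
print("AB"[predicted])  # B
```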
Usage Examples
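from colossal_eval.models.chatglm import ChatGLM2Model
# Initialize the ChatGLM2 model wrapper for evaluation.
# ChatGLM checkpoints ship custom modeling/tokenizer code, so
# trust_remote_code is passed to both the tokenizer and the model.
model = ChatGLM2Model(
    path="THUDM/chatglm2-6b",
    model_max_length=2048,
    tokenizer_kwargs={"trust_remote_code": True},
    model_kwargs={"trust_remote_code": True},
    batch_size=4,
)
# Calculate loss on target tokens only; each prompt may have several targets
losses, token_nums, _ = model.get_loss(
    batch_prompt=["What is the capital of France?"],
    batch_target=[["Paris", "The capital of France is Paris."]],
)
# losses holds one sub-list per prompt, with one entry per target
# Generate responses; scores carries choice logits for single-choice tasks
responses, scores = model.generate(
    inputs=["Explain quantum computing in simple terms."],
    max_new_tokens=256,
)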
Key Features
- Target-Only Loss - Calculates cross-entropy loss only on target tokens using IGNORE_INDEX (-100) masking
- Middle Truncation - When a prompt exceeds the length budget (model_max_length minus max_new_tokens), keeps the first and last (model_max_length - max_new_tokens) // 2 tokens and drops the middle
- ChatGLM Token Handling - For ChatGLM v1, builds inputs with the tokenizer's build_inputs_with_special_tokens and accounts for ChatGLM-specific special tokens (gmask_id, bos_token)
- ChatGLM2 Build Prompt - Uses build_prompt method for proper ChatGLM2 conversation formatting
- Choice Logit Extraction - Extracts logits at specific token indices for single-choice evaluation scoring
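- Multi-Target Support - Supports multiple target answers per prompt: each (prompt, target) pair becomes a separate sequence, the sequences are re-batched in chunks of batch_size, and the losses are regrouped per prompt
- Per-Sample Loss - Uses reduction="none" in CrossEntropyLoss to obtain per-token losses, which are summed over each sequence so that every sample gets its own loss instead of a single batch mean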
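Target-only masking and per-sample reduction fit together as follows:
1. The logits are shifted one position against the labels.
2. Positions labelled IGNORE_INDEX are skipped.
3. Token losses are accumulated per sequence rather than averaged over the batch.

This is a minimal pure-Python sketch of that computation for a single sequence, written without PyTorch for clarity. It assumes the token losses are summed and returned together with the count of scored tokens, so the per-token mean is total / count. per_sequence_loss is a hypothetical name.

```python
import math
from typing import List, Tuple

IGNORE_INDEX = -100


def per_sequence_loss(logits: List[List[float]], labels: List[int]) -> Tuple[float, int]:
    """Hypothetical helper: summed cross-entropy over the unmasked tokens.

    logits: [seq_len][vocab] scores for one sequence.
    labels: [seq_len] token ids aligned with the input, IGNORE_INDEX where masked.
    The logit at position t predicts the label at position t + 1 (causal shift).
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue  # prompt (or eos) position: contributes no loss
        row = logits[t]
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]  # -log softmax(row)[target]
        count += 1
    return total, count


# Uniform logits over a 4-token vocabulary: each scored token costs log(4).
logits = [[0.0] * 4 for _ in range(4)]
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 3]
total, count = per_sequence_loss(logits, labels)
print(count, round(total / count, 4))  # 2 1.3863
```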