Implementation:Hpcaitech ColossalAI PolicyLoss
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool for computing the GRPO/PPO clipped policy gradient loss with KL penalty, provided by ColossalChat.
Description
PolicyLoss is a PyTorch module that implements the clipped surrogate loss with a per-token KL divergence penalty. It supports both sample-level and token-level loss aggregation and includes importance ratio clipping to prevent large policy updates.
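As a hedged sketch (not the library's exact code), the per-token loss combines the PPO clipped surrogate with a KL penalty weighted by beta; sample-level aggregation then averages over each sequence's valid tokens:

```python
import torch

def clipped_policy_loss(log_probs, old_log_probs, advantages, per_token_kl,
                        action_mask, clip_eps_low=0.2, clip_eps_high=0.2,
                        beta=0.01):
    """Illustrative sketch of a GRPO/PPO clipped loss with per-token KL penalty."""
    # Importance ratio between the current and old policy
    ratio = torch.exp(log_probs - old_log_probs)
    # Clipped surrogate objective (we maximize reward, so negate for a loss)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps_low, 1.0 + clip_eps_high) * advantages
    per_token_loss = -torch.min(surr1, surr2) + beta * per_token_kl
    # Sample-level aggregation: mean over valid tokens per sample, then over batch
    per_sample = (per_token_loss * action_mask).sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)
    return per_sample.mean()
```

With identical old and new log probabilities the ratio is 1, so the loss reduces to the negated mean advantage plus the KL term.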
Usage
Instantiate with clipping and KL parameters, then call forward() with the current and old log probabilities, advantages, per-token KL divergences, and an optional action mask.
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/distributed/loss.py
- Lines: 8-70
Signature
class PolicyLoss(nn.Module):
    def __init__(
        self,
        clip_eps_low: float = 0.2,
        clip_eps_high: float = 0.2,
        beta: float = 0.01,
        loss_variation: str = "sample_level",
        adv: str = "GRPO",
    ) -> None:
        """
        Args:
            clip_eps_low: Lower clipping epsilon (default: 0.2)
            clip_eps_high: Upper clipping epsilon (default: 0.2)
            beta: KL penalty coefficient (default: 0.01)
            loss_variation: "sample_level" or "token_level"
            adv: Advantage type ("GRPO")
        """

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        per_token_kl: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        total_effective_tokens_in_batch: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute clipped policy loss with KL penalty."""
Import
from coati.distributed.loss import PolicyLoss
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| log_probs | Tensor | Yes | Current policy log probabilities |
| old_log_probs | Tensor | Yes | Old policy log probabilities (from producer) |
| advantages | Tensor | Yes | Group-relative advantages |
| per_token_kl | Tensor | Yes | Per-token KL divergence from reference |
| action_mask | Tensor | No | Mask for valid action tokens |
| loss_mask | Tensor | No | Mask selecting samples included in the loss |
| total_effective_tokens_in_batch | Tensor | No | Token count used for token-level normalization |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Scalar policy gradient loss |
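The loss_variation flag selects how per-token losses are reduced to this scalar output. A minimal sketch of the two modes, assuming "sample_level" averages within each sequence first while "token_level" normalizes by a global effective-token count (the exact reduction inside ColossalChat may differ):

```python
import torch

def aggregate(per_token_loss, action_mask, variation="sample_level",
              total_effective_tokens=None):
    # Sketch only: names mirror the PolicyLoss signature, not its internals.
    if variation == "sample_level":
        # Per-sequence mean over valid tokens, then mean over the batch
        per_sample = (per_token_loss * action_mask).sum(-1) / action_mask.sum(-1).clamp(min=1)
        return per_sample.mean()
    if variation == "token_level":
        # Sum over all valid tokens, normalized by a global token count
        if total_effective_tokens is None:
            total_effective_tokens = action_mask.sum()
        return (per_token_loss * action_mask).sum() / total_effective_tokens.clamp(min=1)
    raise ValueError(f"unknown loss_variation: {variation}")
```

The two modes differ whenever sequences have unequal numbers of valid tokens: sample-level weights each sequence equally, token-level weights each token equally.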
Usage Examples
from coati.distributed.loss import PolicyLoss

policy_loss_fn = PolicyLoss(
    clip_eps_low=0.2,
    clip_eps_high=0.2,
    beta=0.01,
    loss_variation="sample_level",
)

# current_log_probs, old_log_probs, kl_divergence, and action_mask come from
# the rollout/training loop; group_advantages are the group-relative advantages.
loss = policy_loss_fn(
    log_probs=current_log_probs,
    old_log_probs=old_log_probs,
    advantages=group_advantages,
    per_token_kl=kl_divergence,
    action_mask=action_mask,
)
Related Pages
Implements Principle
Environment and Heuristic Links