Implementation:Hpcaitech ColossalAI Generate With Actor
| Knowledge Sources | |
|---|---|
| Domains | Text_Generation, RLHF, Inference |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
generation.py provides sampling-based text generation utilities with temperature, top-k, and top-p control for autoregressive language model inference.
Description
This module implements the core text generation functions used during PPO rollouts and interactive chat. The generate function is the primary entry point. It supports sampling-based generation and expects left-padded inputs. It delegates to the internal _sample function, which produces one token per step: it applies the logits processors (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper), draws the next token by multinomial sampling, and checks for stop tokens.
For streaming use, the module also provides generate_streaming and _sample_streaming, which yield intermediate results at configurable intervals.
Helper functions:
- _prepare_logits_processor builds the logits processor list.
- _is_sequence_finished checks completion in a distributed-aware way, using all_reduce in data-parallel (DP) scenarios.
- update_model_kwargs_fn updates the KV cache and attention masks between steps.
- prepare_inputs_fn prepares the model input dictionary.
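The per-step processing described above (temperature scaling, top-k filtering, top-p filtering, then multinomial sampling) can be illustrated with a minimal pure-Python sketch. This is not the code in generation.py, which operates on torch tensors through the Hugging Face warpers; the names warp_and_sample and softmax are hypothetical and exist only for this illustration.

```python
import math
import random
from typing import List, Optional


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax; masked entries (-inf) get probability 0."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def warp_and_sample(
    logits: List[float],
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Return one sampled token index from raw logits (illustrative only)."""
    rng = rng or random.Random()
    scores = list(logits)

    # Temperature: divide the logits; values < 1 sharpen, values > 1 flatten.
    if temperature is not None and temperature != 1.0:
        scores = [s / temperature for s in scores]

    # Top-k: mask every score below the k-th largest one.
    if top_k is not None and 0 < top_k < len(scores):
        kth = sorted(scores, reverse=True)[top_k - 1]
        scores = [s if s >= kth else -math.inf for s in scores]

    # Top-p: keep the smallest high-probability set whose mass reaches top_p.
    if top_p is not None and top_p < 1.0:
        probs = softmax(scores)
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        keep, mass = set(), 0.0
        for i in order:
            keep.add(i)
            mass += probs[i]
            if mass >= top_p:
                break
        scores = [s if i in keep else -math.inf for i, s in enumerate(scores)]

    # Multinomial sampling over the surviving tokens.
    probs = softmax(scores)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

With top_k=1 the sketch reduces to greedy selection, which is a convenient sanity check when experimenting with the parameters.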
Usage
Use these generation functions during PPO/GRPO experience generation to produce text rollouts from the actor model. The generate function is called by NaiveExperienceMaker.make_experience and can also be used in standalone inference scripts.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/models/generation.py
- Lines: 1-435
Signature
@torch.inference_mode()
def generate(
    model: Any,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor
@torch.inference_mode()
def generate_streaming(
    model: Any,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable] = None,
    update_model_kwargs_fn: Optional[Callable] = None,
    **model_kwargs,
) -> Generator[torch.Tensor, None, None]
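The Generator return type means callers receive partial sequences while decoding is still in progress. A minimal sketch of that pattern follows, in pure Python with token lists instead of tensors. The function stream_tokens, its next_token callback, and the stream_interval parameter are hypothetical names chosen for illustration; the real interval setting and loop in _sample_streaming may differ.

```python
from typing import Callable, Iterator, List


def stream_tokens(
    prompt: List[int],
    next_token: Callable[[List[int]], int],
    max_length: int,
    eos_token_id: int,
    stream_interval: int = 2,  # hypothetical name for the yield interval
) -> Iterator[List[int]]:
    """Yield a snapshot of the sequence every stream_interval steps."""
    sequence = list(prompt)
    steps = 0
    while len(sequence) < max_length:
        token = next_token(sequence)  # stand-in for one model forward + sample
        sequence.append(token)
        steps += 1
        finished = token == eos_token_id or len(sequence) >= max_length
        # Always emit the final state so the consumer sees the full output.
        if finished or steps % stream_interval == 0:
            yield list(sequence)
        if finished:
            break
```

Yielding a copy of the sequence keeps each snapshot stable even though the loop keeps appending tokens afterwards.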
Helper Functions
def _prepare_logits_processor(
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
) -> LogitsProcessorList
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool
def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict
def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict
Import
from coati.models.generation import generate, generate_streaming
I/O Contract
Inputs (generate)
| Name | Type | Required | Description |
|---|---|---|---|
| model | Any | Yes | The language model to use for generation |
| input_ids | torch.Tensor | Yes | Input token IDs (must be left-padded) |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer (must have padding_side="left") |
| max_length | int | Yes | Maximum total sequence length (prompt + generation) |
| num_beams | int | No | Number of beams (only 1 supported; default: 1) |
| do_sample | bool | No | Enable sampling (default: True) |
| early_stopping | bool | No | Stop when all sequences finish (default: True) |
| top_k | Optional[int] | No | Top-k filtering threshold |
| top_p | Optional[float] | No | Top-p (nucleus) sampling threshold |
| temperature | Optional[float] | No | Temperature for logits scaling |
| stop_token_ids | Optional[List[List[int]]] | No | List of stop token ID sequences (passed via model_kwargs) |
| max_new_tokens | Optional[int] | No | Maximum new tokens to generate (passed via model_kwargs) |
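Because stop_token_ids is a list of token ID sequences rather than a flat list of IDs, a stop condition can span several tokens. One plausible way to test it is to compare each stop sequence against the tail of the generated tokens, as in the sketch below. The helper ends_with_stop_sequence is a hypothetical name, and the actual matching logic in _sample may differ.

```python
from typing import List, Sequence


def ends_with_stop_sequence(
    generated: Sequence[int], stop_token_ids: List[List[int]]
) -> bool:
    """Return True if the generated tokens end with any stop sequence."""
    for stop in stop_token_ids:
        # Skip empty stop sequences and ones longer than the output so far.
        if not stop or len(generated) < len(stop):
            continue
        if list(generated[-len(stop):]) == list(stop):
            return True
    return False
```

A single end-of-sequence token is then just the one-element case, for example [[tokenizer.eos_token_id]].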
Outputs (generate)
| Name | Type | Description |
|---|---|---|
| sequences | torch.Tensor | Full token sequences (input_ids + generated tokens) of shape [batch_size, total_length] |
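Left padding is what makes the returned shape easy to work with: every prompt ends at the same column, so the generated continuation of each row starts at the same index. The sketch below shows the idea with plain Python lists; left_pad and split_continuations are hypothetical helpers, not part of generation.py. With tensors, the equivalent slice would be sequences[:, input_ids.size(1):].

```python
from typing import List


def left_pad(prompts: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Pad each prompt on the left so all rows share the same width."""
    width = max(len(p) for p in prompts)
    return [[pad_token_id] * (width - len(p)) + p for p in prompts]


def split_continuations(
    sequences: List[List[int]], prompt_width: int
) -> List[List[int]]:
    """Drop the (padded) prompt columns and keep only the generated tokens."""
    return [row[prompt_width:] for row in sequences]


# Two prompts of different lengths, padded with token ID 0.
padded = left_pad([[1, 2, 3], [4]], pad_token_id=0)
# Pretend the model appended two tokens to each row.
full = [padded[0] + [9, 9], padded[1] + [8, 8]]
continuations = split_continuations(full, prompt_width=len(padded[0]))
```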
Usage Examples
from coati.models.generation import generate, generate_streaming

# actor_model, tokenizer, and input_ids are assumed to be defined already;
# input_ids must be a left-padded batch of prompt token IDs.

# Basic generation
tokenizer.padding_side = "left"  # generate expects left-padded prompts
sequences = generate(
    model=actor_model,
    input_ids=input_ids,
    tokenizer=tokenizer,
    max_length=2048,
    do_sample=True,
    top_p=0.95,
    temperature=0.7,
    stop_token_ids=[[tokenizer.eos_token_id]],
)

# Streaming generation
for partial_output in generate_streaming(
    model=actor_model,
    input_ids=input_ids,
    tokenizer=tokenizer,
    max_length=2048,
    do_sample=True,
    temperature=0.7,
):
    # Process intermediate results (here: decode the first sequence in the batch)
    decoded = tokenizer.decode(partial_output[0], skip_special_tokens=True)