Implementation:Hpcaitech ColossalAI Generate With Actor

From Leeroopedia


Knowledge Sources
Domains: Text_Generation, RLHF, Inference
Last Updated: 2026-02-09 00:00 GMT

Overview

generation.py provides sampling-based text generation utilities with temperature, top-k, and top-p control for autoregressive language model inference.
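The three controls can be illustrated with a minimal pure-Python sketch. It shows what temperature, top-k, and top-p filtering do to a vector of next-token logits before a multinomial draw. It demonstrates the standard techniques and is not coati's code or the transformers warper implementation. The names `softmax`, `warp_logits`, and `sample_token` are hypothetical, and plain lists stand in for tensors.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability; math.exp(-inf) evaluates to 0.0.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def warp_logits(
    logits: List[float],
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> List[float]:
    """Apply temperature, then top-k, then top-p; filtered tokens get -inf."""
    out = list(logits)
    # Temperature: divide the logits (< 1 sharpens, > 1 flattens the distribution).
    if temperature is not None and temperature != 1.0:
        out = [x / temperature for x in out]
    # Top-k: keep tokens whose logit is at least the k-th largest (ties are kept).
    if top_k is not None and 0 < top_k < len(out):
        kth = sorted(out, reverse=True)[top_k - 1]
        out = [x if x >= kth else -math.inf for x in out]
    # Top-p: keep the most likely tokens until their cumulative probability reaches top_p.
    if top_p is not None and top_p < 1.0:
        probs = softmax(out)
        order = sorted(range(len(out)), key=lambda i: probs[i], reverse=True)
        keep, cumulative = set(), 0.0
        for i in order:
            keep.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        out = [x if i in keep else -math.inf for i, x in enumerate(out)]
    return out


def sample_token(logits: List[float], rng: random.Random) -> int:
    # Multinomial draw from softmax(logits); -inf entries have zero probability.
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
warped = warp_logits([2.0, 1.0, 0.5, -1.0], temperature=0.7, top_k=3, top_p=0.95)
print(warped, sample_token(warped, rng))
```

The sketch applies the filters in the order the warpers are listed in the description (temperature, top-k, top-p). This ordering is an assumption about how the processor list is assembled.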

Description

This module implements the core text generation functions used during PPO rollouts and interactive chat. The generate function is the primary entry point. It supports sampling-based generation (num_beams=1, do_sample=True) and requires left-padded inputs. It delegates to the internal _sample function, which produces one token per iteration. At each step, _sample applies the logits processors (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper), draws the next token by multinomial sampling, and checks for stop tokens. The module also provides generate_streaming and _sample_streaming, which yield intermediate results at a configurable interval.

The helper functions are:

_prepare_logits_processor: builds the processor list.
_is_sequence_finished: checks for completion in a distributed-aware way, using all_reduce under data parallelism.
update_model_kwargs_fn: updates the KV cache and attention mask between steps.
prepare_inputs_fn: builds the model input dictionary.
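The token-by-token loop described above can be sketched in pure Python. This is a simplified, hypothetical model of `_sample`, not its source. `toy_sample`, `next_token_logits` (a stand-in for the model's forward pass), and the `PAD`/`EOS` ids are invented for illustration. The sketch omits the KV cache, attention mask, logits warpers, and stop-sequence checks. It shows the batching bookkeeping. Finished rows keep receiving the pad token, so every row stays the same length. With early stopping, the loop exits as soon as no row is unfinished.

```python
import math
import random
from typing import Callable, List, Sequence

PAD, EOS = 0, 1  # hypothetical token ids for this toy example


def sample_token(logits: List[float], rng: random.Random) -> int:
    # Multinomial draw from softmax(logits).
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def toy_sample(
    next_token_logits: Callable[[Sequence[int]], List[float]],
    batch: List[List[int]],
    max_new_tokens: int,
    rng: random.Random,
    early_stopping: bool = True,
) -> List[List[int]]:
    """Generate one token per step for a left-padded batch."""
    seqs = [list(row) for row in batch]
    unfinished = [1] * len(seqs)  # 1 = still generating, 0 = finished
    for _ in range(max_new_tokens):
        # Stand-in for one forward pass: next-token logits for every row.
        next_tokens = [sample_token(next_token_logits(row), rng) for row in seqs]
        # Rows that already finished receive PAD so the batch stays rectangular.
        next_tokens = [t * u + PAD * (1 - u) for t, u in zip(next_tokens, unfinished)]
        for row, t in zip(seqs, next_tokens):
            row.append(t)
        # A row finishes once it emits EOS.
        unfinished = [u * int(t != EOS) for u, t in zip(unfinished, next_tokens)]
        if early_stopping and max(unfinished) == 0:
            break
    return seqs
```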

Usage

Use these generation functions during PPO/GRPO experience generation to produce text rollouts from the actor model. The generate function is called by NaiveExperienceMaker.make_experience and can also be used in standalone inference scripts.

Code Reference

Source Location

Signature

@torch.inference_mode()
def generate(
    model: Any,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor

@torch.inference_mode()
def generate_streaming(
    model: Any,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable] = None,
    update_model_kwargs_fn: Optional[Callable] = None,
    **model_kwargs,
) -> Generator[torch.Tensor, None, None]
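Because generate_streaming returns a generator, callers can consume partial sequences while sampling continues. The sketch below uses a hypothetical `stream_in_chunks` helper with an assumed `stream_interval` parameter. It yields a snapshot of the growing sequence every few steps and flushes any remaining tokens at the end. The exact yield schedule of coati's `_sample_streaming` may differ.

```python
from typing import Iterable, Iterator, List


def stream_in_chunks(
    new_tokens: Iterable[int], prompt: List[int], stream_interval: int = 2
) -> Iterator[List[int]]:
    """Yield a snapshot of prompt + generated tokens every `stream_interval` steps."""
    seq = list(prompt)
    steps = 0
    for tok in new_tokens:  # stands in for one sampling step per token
        seq.append(tok)
        steps += 1
        if steps % stream_interval == 0:
            yield list(seq)  # copy, so later steps do not mutate what the caller holds
    if steps % stream_interval != 0:
        yield list(seq)  # flush the final partial chunk


for partial in stream_in_chunks([5, 6, 7], prompt=[1]):
    print(partial)
```

The two signatures above also have different defaults: generate_streaming uses early_stopping=False, while generate uses early_stopping=True. Pass early_stopping explicitly if you want both functions to behave the same way.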

Helper Functions

def _prepare_logits_processor(
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
) -> LogitsProcessorList

def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool

def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict

def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict
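`_is_sequence_finished` is described as using all_reduce under data parallelism. The sketch below simulates that decision with plain lists. Each inner list holds one rank's `unfinished` flags, and an element-wise sum stands in for a SUM all-reduce. Generation counts as finished only when no rank has an unfinished sequence. The function names are hypothetical. Presumably the goal is to have every rank keep stepping until all ranks are done, which keeps ranks whose forward passes involve collective communication in lockstep.

```python
from typing import List


def all_reduce_sum(per_rank: List[List[int]]) -> List[int]:
    """Simulate dist.all_reduce with SUM: element-wise sum of every rank's flags."""
    return [sum(values) for values in zip(*per_rank)]


def is_sequence_finished(per_rank_unfinished: List[List[int]]) -> bool:
    # Every rank receives the same reduced vector, so all ranks agree on when to stop.
    reduced = all_reduce_sum(per_rank_unfinished)
    return max(reduced) == 0


# Rank 0 is done, but rank 1 still has an unfinished sequence, so nobody stops yet.
print(is_sequence_finished([[0, 0], [0, 1]]))  # False
```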

Import

from coati.models.generation import generate, generate_streaming

I/O Contract

Inputs (generate)

Name | Type | Required | Description
model | Any | Yes | The language model used for generation
input_ids | torch.Tensor | Yes | Prompt token IDs (must be left-padded)
tokenizer | PreTrainedTokenizer | Yes | Tokenizer (must have padding_side="left")
max_length | int | Yes | Maximum total sequence length (prompt + generated tokens)
num_beams | int | No | Number of beams (only 1 is supported; default: 1)
do_sample | bool | No | Enable sampling (default: True)
early_stopping | bool | No | Stop as soon as every sequence in the batch has finished (default: True)
top_k | Optional[int] | No | Keep only the k highest-scoring tokens before sampling
top_p | Optional[float] | No | Nucleus sampling threshold: keep the smallest set of most likely tokens whose cumulative probability reaches top_p
temperature | Optional[float] | No | Value the logits are divided by before sampling (values below 1 sharpen the distribution)
stop_token_ids | Optional[List[List[int]]] | No | List of stop token ID sequences (passed via model_kwargs)
max_new_tokens | Optional[int] | No | Maximum number of new tokens to generate (passed via model_kwargs)
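`stop_token_ids` is a list of sequences, so a single stop token is written as a one-element list, for example `[[tokenizer.eos_token_id]]`. One natural reading is sketched below using a hypothetical `update_unfinished` helper and plain lists instead of tensors. After each step, the last `len(stop)` tokens of every row are compared with each stop sequence. A match marks the row finished. Treat this exact matching rule as an assumption.

```python
from typing import List


def update_unfinished(
    seqs: List[List[int]], unfinished: List[int], stop_token_ids: List[List[int]]
) -> List[int]:
    """Mark a row finished (0) when its last tokens equal any stop sequence."""
    for stop in stop_token_ids:
        n = len(stop)
        # Compare the tail of each row with this stop sequence; finished rows stay finished.
        unfinished = [u * int(row[-n:] != stop) for u, row in zip(unfinished, seqs)]
    return unfinished


eos_id = 2  # hypothetical EOS id
print(update_unfinished([[5, 7, 2], [5, 8, 7]], [1, 1], [[eos_id]]))  # [0, 1]
```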

Outputs (generate)

Name | Type | Description
sequences | torch.Tensor | Full token sequences (input_ids followed by the generated tokens), shape [batch_size, total_length]
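Under this output contract, the prompt (including its left padding) occupies the first `input_ids.size(1)` columns of `sequences`, so slicing those columns off recovers the completions. The pure-Python sketch below uses hypothetical helper names. It also suggests a likely reason for the left-padding requirement. An autoregressive sampler reads next-token logits at the final position of each row, and only left padding guarantees that this position holds a real prompt token in every row. In PyTorch, the slice would be `sequences[:, prompt_len:]`.

```python
from typing import List


def left_pad(batch: List[List[int]], pad_id: int) -> List[List[int]]:
    width = max(len(row) for row in batch)
    return [[pad_id] * (width - len(row)) + row for row in batch]


def right_pad(batch: List[List[int]], pad_id: int) -> List[List[int]]:
    width = max(len(row) for row in batch)
    return [row + [pad_id] * (width - len(row)) for row in batch]


def completions(sequences: List[List[int]], prompt_len: int) -> List[List[int]]:
    # sequences = prompt (padded to prompt_len) + generated tokens, so drop the prompt columns.
    return [row[prompt_len:] for row in sequences]


prompts = [[11, 12, 13], [21]]
left, right = left_pad(prompts, 0), right_pad(prompts, 0)
print([row[-1] for row in left])   # [13, 21]: last real token of every prompt
print([row[-1] for row in right])  # [13, 0]: the short prompt ends in padding
```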

Usage Examples

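from coati.models.generation import generate, generate_streaming

# Assumes `actor_model` (a causal LM) and `tokenizer` are already loaded,
# and that the tokenizer defines a pad token.
# Set left padding before tokenizing the prompts.
tokenizer.padding_side = "left"
prompts = ["What is RLHF?", "Explain PPO in one sentence."]  # example prompts
device = next(actor_model.parameters()).device
input_ids = tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"].to(device)

# Basic generation: returns prompt + generated tokens, shape [batch_size, total_length]
sequences = generate(
    model=actor_model,
    input_ids=input_ids,
    tokenizer=tokenizer,
    max_length=2048,
    do_sample=True,
    top_p=0.95,
    temperature=0.7,
    stop_token_ids=[[tokenizer.eos_token_id]],  # forwarded through **model_kwargs
)
# Drop the (left-padded) prompt columns to keep only the completions
completions = tokenizer.batch_decode(sequences[:, input_ids.size(1):], skip_special_tokens=True)

# Streaming generation: yields partial sequences while generating
for partial_output in generate_streaming(
    model=actor_model,
    input_ids=input_ids,
    tokenizer=tokenizer,
    max_length=2048,
    do_sample=True,
    temperature=0.7,
):
    # Decode the first sequence generated so far
    decoded = tokenizer.decode(partial_output[0], skip_special_tokens=True)
    print(decoded)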
