

Implementation:NVIDIA NeMo Aligner CAI Utils



Knowledge Sources
Domains: Constitutional AI, Inference Utilities, Prompt Templating
Last Updated: 2026-02-08 00:00 GMT

Overview

A utility module providing remote inference functions and prompt template classes used by the Constitutional AI (CAI) dataset generation scripts in NeMo Aligner.

Description

cai_utils.py is the shared utility module for the CAI pipeline. It provides three main capabilities:

  1. Remote inference: The remote_inference function sends prompts to a locally hosted NeMo Megatron GPT inference service (the server launched with megatron_gpt_eval.py) at a given host and port, while remote_inference_with_ngc calls NVIDIA NGC-hosted models through the NGC API.
  2. Prompt templating: The PromptTemplate base class and UserAssistantPromptTemplate subclass handle formatting of chat-style conversations with configurable role formats, BOS/EOS tokens, system messages, and response extraction patterns. They support single messages, multi-turn conversations, and batched inputs.
  3. Dialog filtering: The remove_long_dialogs function drops conversations whose tokenized length exceeds a maximum sequence length (max_seq_length), optionally spreading the work across a multiprocessing pool (use_pool) to speed up large files.
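The length filter in item 3 can be sketched as follows. This is a hypothetical illustration, not the cai_utils code: filter_long_dialogs, dialog_length and whitespace_tokenize are names invented here, the record layout (a NeMo-style "conversations" list of {"from", "value"} turns) is an assumption, and whitespace splitting stands in for the real tokenizer chosen by tokenizer_model and tokenizer_library. It counts only message text, ignoring any template tokens a real implementation might include.

```python
# Hypothetical sketch of length-based dialog filtering; NOT the cai_utils implementation.
from functools import partial
from multiprocessing import Pool
from typing import Callable, Dict, List


def whitespace_tokenize(text: str) -> List[str]:
    """Stand-in tokenizer (assumption): splits on whitespace instead of using a model tokenizer."""
    return text.split()


def dialog_length(dialog: Dict, tokenize: Callable[[str], List[str]]) -> int:
    """Sum token counts over all turns of an assumed {"conversations": [{"from", "value"}]} record."""
    return sum(len(tokenize(turn["value"])) for turn in dialog["conversations"])


def filter_long_dialogs(
    dialogs: List[Dict],
    max_seq_length: int,
    tokenize: Callable[[str], List[str]] = whitespace_tokenize,
    use_pool: bool = False,
) -> List[Dict]:
    """Return only the dialogs whose total token count is at most max_seq_length."""
    length_fn = partial(dialog_length, tokenize=tokenize)
    if use_pool:
        # Count tokens in worker processes; tokenize must be a picklable top-level function.
        with Pool() as pool:
            lengths = pool.map(length_fn, dialogs)
    else:
        lengths = [length_fn(dialog) for dialog in dialogs]
    return [dialog for dialog, n in zip(dialogs, lengths) if n <= max_seq_length]


dialogs = [
    {"conversations": [{"from": "User", "value": "hi"}, {"from": "Assistant", "value": "hello there"}]},
    {"conversations": [{"from": "User", "value": "a b c d e f"}]},
]
kept = filter_long_dialogs(dialogs, max_seq_length=4)  # keeps only the first dialog (3 tokens)
```

Computing all lengths first and filtering afterwards keeps the parallel and sequential paths identical apart from how the counts are produced.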

The module also includes a ChatTemplateHelper class for validating and collating batched chat messages into a format suitable for Megatron inference.
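The validation and collation attributed to ChatTemplateHelper might look roughly like the sketch below. The function names, the allowed role set, and the specific checks are assumptions made for illustration; the real class may check different fields or produce a different batch layout.

```python
# Hypothetical sketch of chat-batch validation and collation; NOT the ChatTemplateHelper source.
from typing import Dict, List, Union

Message = Dict[str, str]
ALLOWED_ROLES = {"System", "User", "Assistant"}  # assumption: role names used by the templates


def validate_conversation(conversation: List[Message]) -> None:
    """Raise ValueError if a conversation is empty or contains a malformed message."""
    if not conversation:
        raise ValueError("conversation must contain at least one message")
    for i, msg in enumerate(conversation):
        if set(msg) != {"role", "content"}:
            raise ValueError(f"message {i} must have exactly 'role' and 'content' keys")
        if msg["role"] not in ALLOWED_ROLES:
            raise ValueError(f"message {i} has unknown role {msg['role']!r}")
        if not isinstance(msg["content"], str):
            raise ValueError(f"message {i} content must be a string")


def collate_chat_batch(batch: Union[List[Message], List[List[Message]]]) -> List[List[Message]]:
    """Normalize one conversation or a batch of conversations into a validated batch."""
    # A single conversation (a list of dicts) is wrapped into a batch of one.
    if batch and isinstance(batch[0], dict):
        batch = [batch]
    for conversation in batch:
        validate_conversation(conversation)
    return [list(conversation) for conversation in batch]
```

Normalizing to a list of conversations up front lets downstream code treat single and batched inputs the same way.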

Usage

Use this module when:

  • You need to call a local NeMo Megatron inference service or NVIDIA NGC API for text generation
  • You need to format prompts using User/Assistant/System role templates (e.g., <extra_id_*> or Mistral-Instruct formats)
  • You need to extract model responses from generated text
  • You need to filter out overly long dialogs from a dataset

Code Reference

Source Location

Signature

remote_inference:

def remote_inference(
    prompt: Union[str, List[str], List[dict], List[List[dict]]],
    port: int,
    host: str,
    temperature: Optional[float] = None,
    greedy: Optional[bool] = None,
    tokens_to_generate: Optional[int] = None,
    min_tokens_to_generate: Optional[int] = None,
    add_bos: Optional[bool] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    all_probs: Optional[bool] = None,
    repetition_penalty: Optional[float] = None,
    end_strings: Optional[Union[List[str], str]] = None,
):

remote_inference_with_ngc:

def remote_inference_with_ngc(
    api_key: str,
    url: str,
    model: str,
    prompt: Optional[str] = None,
    messages: Optional[List[dict]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
):

UserAssistantPromptTemplate:

class UserAssistantPromptTemplate(PromptTemplate):
    def __init__(
        self,
        user_format: str,
        assistant_format: str,
        system_format: Optional[str] = None,
        system_default_message: Optional[str] = None,
        bos_token: Optional[str] = None,
        eos_token: Optional[str] = None,
        response_extract_pattern: Optional[str] = None,
    ):
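To make the role-format mechanics concrete, here is a simplified, hypothetical stand-in for UserAssistantPromptTemplate (SimpleUserAssistantTemplate is not part of cai_utils). It assumes each format string contains a {MESSAGE} placeholder, that a leading "System" message overrides system_default_message, and that extraction keeps the text after the last response_extract_pattern and cuts it at eos_token; the real class may differ in these details.

```python
# Hypothetical simplified template; NOT the cai_utils UserAssistantPromptTemplate.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SimpleUserAssistantTemplate:
    user_format: str
    assistant_format: str
    system_format: Optional[str] = None
    system_default_message: Optional[str] = None
    bos_token: Optional[str] = None
    eos_token: Optional[str] = None
    response_extract_pattern: Optional[str] = None

    @staticmethod
    def _fill(fmt: str, content: str) -> str:
        # str.replace avoids str.format failing on literal braces inside message content.
        return fmt.replace("{MESSAGE}", content)

    def format_messages(self, messages: List[Dict[str, str]]) -> str:
        parts = [self.bos_token or ""]
        # Assumption: an explicit leading System message replaces the default system message.
        if messages and messages[0]["role"] == "System":
            system, messages = messages[0]["content"], messages[1:]
        else:
            system = self.system_default_message
        if self.system_format is not None and system is not None:
            parts.append(self._fill(self.system_format, system))
        role_formats = {"User": self.user_format, "Assistant": self.assistant_format}
        for msg in messages:
            # Unknown or mid-conversation System roles raise KeyError.
            parts.append(self._fill(role_formats[msg["role"]], msg["content"]))
        return "".join(parts)

    def extract_response(self, text: str) -> str:
        # Keep the text after the last assistant marker, then cut at the first eos_token.
        if self.response_extract_pattern and self.response_extract_pattern in text:
            text = text.rsplit(self.response_extract_pattern, 1)[1]
        if self.eos_token and self.eos_token in text:
            text = text.split(self.eos_token, 1)[0]
        return text.strip()
```

With the <extra_id_*> formats from Example 3, a single User message "Hi" formats to "<extra_id_0>System\n\n<extra_id_1>User\nHi\n<extra_id_1>Assistant\n", so the model's reply begins right after the assistant marker that extract_response searches for.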

remove_long_dialogs:

def remove_long_dialogs(
    input_file_path: str,
    max_seq_length: int,
    tokenizer_model: str,
    tokenizer_library: str,
    output_dir: str,
    use_pool: bool,
):

Import

from cai_utils import remote_inference, remote_inference_with_ngc, UserAssistantPromptTemplate
from cai_utils import remove_long_dialogs, ChatTemplateHelper, PromptTemplate

I/O Contract

Inputs (remote_inference)

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| prompt | Union[str, List[str], List[dict], List[List[dict]]] | Yes | Text prompt(s) or chat message(s) to send for inference |
| port | int | Yes | Port number of the inference service |
| host | str | Yes | Hostname or IP address of the inference service |
| temperature | Optional[float] | No | Sampling temperature; a small offset (0.00000001) is added internally so that temperature=0 is accepted |
| greedy | Optional[bool] | No | Whether to use greedy decoding |
| tokens_to_generate | Optional[int] | No | Maximum number of tokens to generate |
| min_tokens_to_generate | Optional[int] | No | Minimum number of tokens to generate |
| add_bos | Optional[bool] | No | Whether to add a begin-of-sequence token |
| top_k | Optional[int] | No | Top-k sampling parameter |
| top_p | Optional[float] | No | Top-p (nucleus) sampling parameter |
| all_probs | Optional[bool] | No | Whether to return all token probabilities |
| repetition_penalty | Optional[float] | No | Repetition penalty factor |
| end_strings | Optional[Union[List[str], str]] | No | String(s) that signal the end of generation |
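A request to the Megatron service could be assembled roughly as below. This is a sketch under stated assumptions, not the remote_inference source: it assumes the NeMo text-generation server accepts a JSON PUT at /generate with keys such as "sentences", "tokens_to_generate" and "add_BOS" and answers with a "sentences" list, and that omitted keys fall back to server defaults. build_generate_payload and send_generate_request are hypothetical names, and only string prompts are handled here (not the chat-message forms of prompt).

```python
# Hypothetical request builder for a NeMo Megatron text-generation server; key names are assumptions.
import json
import urllib.request
from typing import Any, Dict, List, Optional, Union

TEMPERATURE_EPSILON = 1e-8  # mirrors the documented 0.00000001 offset that allows temperature=0


def build_generate_payload(
    prompt: Union[str, List[str]],
    temperature: Optional[float] = None,
    greedy: Optional[bool] = None,
    tokens_to_generate: Optional[int] = None,
    min_tokens_to_generate: Optional[int] = None,
    add_bos: Optional[bool] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    all_probs: Optional[bool] = None,
    repetition_penalty: Optional[float] = None,
    end_strings: Optional[Union[List[str], str]] = None,
) -> Dict[str, Any]:
    sentences = [prompt] if isinstance(prompt, str) else list(prompt)
    if isinstance(end_strings, str):
        end_strings = [end_strings]
    if temperature is not None:
        temperature = temperature + TEMPERATURE_EPSILON
    fields = {
        "sentences": sentences,
        "temperature": temperature,
        "greedy": greedy,
        "tokens_to_generate": tokens_to_generate,
        "min_tokens_to_generate": min_tokens_to_generate,
        "add_BOS": add_bos,
        "top_k": top_k,
        "top_p": top_p,
        "all_probs": all_probs,
        "repetition_penalty": repetition_penalty,
        "end_strings": end_strings,
    }
    # Drop unset options (assumption: the server applies its own defaults for missing keys).
    return {key: value for key, value in fields.items() if value is not None}


def send_generate_request(payload: Dict[str, Any], host: str, port: int, timeout: float = 600.0) -> List[str]:
    """PUT the payload to the assumed /generate endpoint and return the "sentences" list."""
    request = urllib.request.Request(
        f"http://{host}:{port}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="PUT",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))["sentences"]
```

Building the payload separately from sending it keeps the parameter handling (epsilon, string-to-list normalization) testable without a running server.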

Inputs (remote_inference_with_ngc)

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| api_key | str | Yes | NGC API key for authentication |
| url | str | Yes | NGC API endpoint URL (e.g., https://integrate.api.nvidia.com/v1/chat/completions) |
| model | str | Yes | Model name on NGC (e.g., "mistralai/mixtral-8x7b-instruct-v0.1") |
| prompt | Optional[str] | No | Single string prompt (mutually exclusive with messages) |
| messages | Optional[List[dict]] | No | Chat messages list (mutually exclusive with prompt) |
| temperature | Optional[float] | No | Sampling temperature |
| top_p | Optional[float] | No | Top-p sampling parameter |
| max_tokens | Optional[int] | No | Maximum number of tokens to generate |
| seed | Optional[int] | No | Random seed for reproducibility |
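The example endpoint above follows the OpenAI-style chat-completions format, so a call can be sketched as below. This is an assumption-laden illustration, not the remote_inference_with_ngc source: build_ngc_chat_payload and post_ngc_chat are hypothetical names, and wrapping a plain prompt as a single "user" message is an assumed way of honoring the prompt/messages exclusivity.

```python
# Hypothetical sketch of an OpenAI-compatible chat-completions call; NOT the cai_utils code.
import json
import urllib.request
from typing import Any, Dict, List, Optional


def build_ngc_chat_payload(
    model: str,
    prompt: Optional[str] = None,
    messages: Optional[List[dict]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    # Enforce the documented mutual exclusivity of prompt and messages.
    if (prompt is None) == (messages is None):
        raise ValueError("pass exactly one of prompt or messages")
    if prompt is not None:
        messages = [{"role": "user", "content": prompt}]  # assumption: prompt becomes one user turn
    payload: Dict[str, Any] = {"model": model, "messages": messages}
    for key, value in (("temperature", temperature), ("top_p", top_p), ("max_tokens", max_tokens), ("seed", seed)):
        if value is not None:  # keeps temperature=0, drops unset options
            payload[key] = value
    return payload


def post_ngc_chat(api_key: str, url: str, payload: Dict[str, Any], timeout: float = 120.0) -> str:
    """POST the payload with a bearer token and return the first choice's message text."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.loads(response.read().decode("utf-8"))
    # OpenAI-compatible responses place the generated text at choices[0].message.content.
    return body["choices"][0]["message"]["content"]
```

Returning only the first choice's text matches the single-string response_message output described for remote_inference_with_ngc.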

Outputs

| Name | Type | Description |
| --- | --- | --- |
| sentences (remote_inference) | List[str] | List of generated text responses from the inference service |
| response_message (remote_inference_with_ngc) | str | Single generated text response from the NGC API |

Usage Examples

# Example 1: Remote inference with a local Megatron service
from cai_utils import remote_inference

responses = remote_inference(
    prompt=["What is 2+2?", "Tell me a joke."],
    port=5656,
    host="localhost",
    temperature=1.0,
    tokens_to_generate=1024,
    greedy=False,
    end_strings=["<extra_id_1>"],
)

# Example 2: Remote inference with NGC API
from cai_utils import remote_inference_with_ngc

response = remote_inference_with_ngc(
    api_key="your-ngc-api-key",
    url="https://integrate.api.nvidia.com/v1/chat/completions",
    model="mistralai/mixtral-8x7b-instruct-v0.1",
    prompt="Calculate 3+4=?",
    temperature=0,
    max_tokens=1024,
)

# Example 3: Using UserAssistantPromptTemplate
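from cai_utils import UserAssistantPromptTemplate, remote_inference

template = UserAssistantPromptTemplate(
    user_format="<extra_id_1>User\n{MESSAGE}\n<extra_id_1>Assistant\n",
    assistant_format="{MESSAGE}\n",
    system_format="<extra_id_0>System\n{MESSAGE}\n",
    system_default_message="",
    eos_token="<extra_id_1>",
    response_extract_pattern="<extra_id_1>Assistant\n",
)

# Build a multi-turn prompt that ends with an open Assistant turn
prompt = template.format_messages([
    {"role": "User", "content": "Calculate the sum of 2 and 3."},
    {"role": "Assistant", "content": "The sum of 2 and 3 is 5."},
    {"role": "User", "content": "Thank you! Now calculate 5 + 7."},
])

# Generate a completion for the formatted prompt (assumes the local service from Example 1)
generated_text = remote_inference(
    prompt=prompt,
    port=5656,
    host="localhost",
    tokens_to_generate=1024,
    end_strings=["<extra_id_1>"],
)[0]

# Keep only the assistant's reply from the generated text
extracted = template.extract_response(generated_text)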
