Implementation:NVIDIA NeMo Aligner CAI Utils
| Knowledge Sources | |
|---|---|
| Domains | Constitutional AI, Inference Utilities, Prompt Templating |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A utility module providing remote inference functions and prompt template classes used by the Constitutional AI (CAI) dataset generation scripts in NeMo Aligner.
Description
cai_utils.py is the shared utility module for the CAI pipeline. It provides three main capabilities:
- Remote inference: The remote_inference function sends prompts to a locally hosted NeMo Megatron GPT inference service (via megatron_gpt_eval.py), while remote_inference_with_ngc calls NVIDIA NGC-hosted models through the NGC API.
- Prompt templating: The PromptTemplate base class and UserAssistantPromptTemplate subclass handle formatting of chat-style conversations with configurable role formats, BOS/EOS tokens, system messages, and response extraction patterns. They support single messages, multi-turn conversations, and batched inputs.
- Dialog filtering: The remove_long_dialogs function filters out conversations that exceed a maximum sequence length after tokenization, using multiprocessing for efficiency.
The module also includes a ChatTemplateHelper class for validating and collating batched chat messages into a format suitable for Megatron inference.
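The dialog filtering capability described above can be illustrated with a minimal sketch. It is not the module's implementation: the function names, the dialog schema (a "turns" list with "content" strings), and the whitespace-based token count are all assumptions made for illustration. The real remove_long_dialogs reads from a file and uses the tokenizer selected by tokenizer_model and tokenizer_library.

```python
from multiprocessing import Pool
from typing import Dict, List


def count_tokens(dialog: Dict) -> int:
    """Stand-in tokenizer (assumption): count whitespace-separated tokens in all turns."""
    return sum(len(turn["content"].split()) for turn in dialog["turns"])


def filter_long_dialogs(
    dialogs: List[Dict], max_seq_length: int, use_pool: bool = False
) -> List[Dict]:
    """Keep only dialogs whose token count does not exceed max_seq_length."""
    if use_pool:
        # Each dialog is measured independently, so the work can be spread
        # across worker processes.
        with Pool() as pool:
            lengths = pool.map(count_tokens, dialogs)
    else:
        lengths = [count_tokens(dialog) for dialog in dialogs]
    return [dialog for dialog, n in zip(dialogs, lengths) if n <= max_seq_length]
```

When running the pooled variant as a script, calling it under an `if __name__ == "__main__":` guard avoids problems on platforms that spawn worker processes instead of forking.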
Usage
Use this module when:
- You need to call a local NeMo Megatron inference service or NVIDIA NGC API for text generation
- You need to format prompts using User/Assistant/System role templates (e.g., <extra_id_*> or Mistral-Instruct formats)
- You need to extract model responses from generated text
- You need to filter out overly long dialogs from a dataset
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/nlp/cai/cai_utils.py
- Lines: 1-617
Signature
remote_inference:
def remote_inference(
    prompt: Union[str, List[str], List[dict], List[List[dict]]],
    port: int,
    host: str,
    temperature: Optional[float] = None,
    greedy: Optional[bool] = None,
    tokens_to_generate: Optional[int] = None,
    min_tokens_to_generate: Optional[int] = None,
    add_bos: Optional[bool] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    all_probs: Optional[bool] = None,
    repetition_penalty: Optional[float] = None,
    end_strings: Optional[Union[List[str], str]] = None,
):
remote_inference_with_ngc:
def remote_inference_with_ngc(
    api_key: str,
    url: str,
    model: str,
    prompt: Optional[str] = None,
    messages: Optional[List[dict]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
):
UserAssistantPromptTemplate:
class UserAssistantPromptTemplate(PromptTemplate):
    def __init__(
        self,
        user_format: str,
        assistant_format: str,
        system_format: Optional[str] = None,
        system_default_message: Optional[str] = None,
        bos_token: Optional[str] = None,
        eos_token: Optional[str] = None,
        response_extract_pattern: Optional[str] = None,
    ):
remove_long_dialogs:
def remove_long_dialogs(
    input_file_path: str,
    max_seq_length: int,
    tokenizer_model: str,
    tokenizer_library: str,
    output_dir: str,
    use_pool: bool,
):
Import
from cai_utils import remote_inference, remote_inference_with_ngc, UserAssistantPromptTemplate
from cai_utils import remove_long_dialogs, ChatTemplateHelper, PromptTemplate
I/O Contract
Inputs (remote_inference)
| Name | Type | Required | Description |
|---|---|---|---|
| prompt | Union[str, List[str], List[dict], List[List[dict]]] | Yes | Text prompt(s) or chat message(s) to send for inference |
| port | int | Yes | Port number of the inference service |
| host | str | Yes | Hostname or IP address of the inference service |
| temperature | Optional[float] | No | Sampling temperature (internally adds 0.00000001 to support temperature=0) |
| greedy | Optional[bool] | No | Whether to use greedy decoding |
| tokens_to_generate | Optional[int] | No | Maximum number of tokens to generate |
| min_tokens_to_generate | Optional[int] | No | Minimum number of tokens to generate |
| add_bos | Optional[bool] | No | Whether to add begin-of-sequence token |
| top_k | Optional[int] | No | Top-k sampling parameter |
| top_p | Optional[float] | No | Top-p (nucleus) sampling parameter |
| all_probs | Optional[bool] | No | Whether to return all token probabilities |
| repetition_penalty | Optional[float] | No | Repetition penalty factor |
| end_strings | Optional[Union[List[str], str]] | No | String(s) that signal end of generation |
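The temperature note in the table above can be sketched as follows. This is an illustration under stated assumptions, not the module's code: the helper name build_request_payload, the "sentences" payload key, and the choice to omit parameters left as None are all hypothetical. Only the 0.00000001 offset comes from the table.

```python
from typing import Any, Dict, List, Optional, Union

# Offset taken from the table above; it keeps temperature=0 strictly positive.
TEMPERATURE_EPSILON = 0.00000001


def build_request_payload(
    prompt: Union[str, List[str]],
    temperature: Optional[float] = None,
    **optional_params: Any,
) -> Dict[str, Any]:
    """Assemble a request body, skipping optional parameters that are None."""
    # Assumption: a single string is wrapped into a one-element batch.
    sentences = [prompt] if isinstance(prompt, str) else list(prompt)
    payload: Dict[str, Any] = {"sentences": sentences}  # key name is an assumption
    if temperature is not None:
        payload["temperature"] = temperature + TEMPERATURE_EPSILON
    for name, value in optional_params.items():
        if value is not None:
            payload[name] = value
    return payload
```

Adding a tiny positive offset means a caller can pass temperature=0 without the service ever seeing an exact zero, which would otherwise be a problem for samplers that divide logits by the temperature.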
Inputs (remote_inference_with_ngc)
| Name | Type | Required | Description |
|---|---|---|---|
| api_key | str | Yes | NGC API key for authentication |
| url | str | Yes | NGC API endpoint URL (e.g., https://integrate.api.nvidia.com/v1/chat/completions) |
| model | str | Yes | Model name on NGC (e.g., "mistralai/mixtral-8x7b-instruct-v0.1") |
| prompt | Optional[str] | No | Single string prompt (mutually exclusive with messages) |
| messages | Optional[List[dict]] | No | Chat messages list (mutually exclusive with prompt) |
| temperature | Optional[float] | No | Sampling temperature |
| top_p | Optional[float] | No | Top-p sampling parameter |
| max_tokens | Optional[int] | No | Maximum tokens to generate |
| seed | Optional[int] | No | Random seed for reproducibility |
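Because prompt and messages are mutually exclusive, a caller-side check along the following lines may be useful. This is a sketch only: build_chat_payload is a hypothetical helper, and both the wrapping of a bare prompt into a single "user" message and the payload field layout are assumptions based on the chat-completions style of the example URL, not on the module's source.

```python
from typing import Any, Dict, List, Optional


def build_chat_payload(
    model: str,
    prompt: Optional[str] = None,
    messages: Optional[List[dict]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """Validate the prompt/messages choice and build a chat-style request body."""
    # Exactly one of the two inputs must be provided.
    if (prompt is None) == (messages is None):
        raise ValueError("pass exactly one of `prompt` or `messages`")
    if messages is None:
        # Assumption: a bare prompt becomes a single user message.
        messages = [{"role": "user", "content": prompt}]
    payload: Dict[str, Any] = {"model": model, "messages": messages}
    optional = {
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": seed,
    }
    # Leave out anything the caller did not set.
    payload.update({k: v for k, v in optional.items() if v is not None})
    return payload
```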
Outputs
| Name | Type | Description |
|---|---|---|
| sentences (remote_inference) | List[str] | List of generated text responses from the inference service |
| response_message (remote_inference_with_ngc) | str | Single generated text response from NGC API |
Usage Examples
# Example 1: Remote inference with a local Megatron service
from cai_utils import remote_inference

responses = remote_inference(
    prompt=["What is 2+2?", "Tell me a joke."],
    port=5656,
    host="localhost",
    temperature=1.0,
    tokens_to_generate=1024,
    greedy=False,
    end_strings=["<extra_id_1>"],
)

# Example 2: Remote inference with NGC API
from cai_utils import remote_inference_with_ngc

response = remote_inference_with_ngc(
    api_key="your-ngc-api-key",
    url="https://integrate.api.nvidia.com/v1/chat/completions",
    model="mistralai/mixtral-8x7b-instruct-v0.1",
    prompt="Calculate 3+4=?",
    temperature=0,
    max_tokens=1024,
)

# Example 3: Using UserAssistantPromptTemplate
from cai_utils import UserAssistantPromptTemplate

template = UserAssistantPromptTemplate(
    user_format="<extra_id_1>User\n{MESSAGE}\n<extra_id_1>Assistant\n",
    assistant_format="{MESSAGE}\n",
    system_format="<extra_id_0>System\n{MESSAGE}\n",
    system_default_message="",
    eos_token="<extra_id_1>",
    response_extract_pattern="<extra_id_1>Assistant\n",
)

# Format a multi-turn conversation into a single prompt string
prompt = template.format_messages([
    {"role": "User", "content": "Calculate the sum of 2 and 3."},
    {"role": "Assistant", "content": "The sum of 2 and 3 is 5."},
    {"role": "User", "content": "Thank you! Now calculate 5 + 7."},
])

# generated_text is a placeholder for raw model output obtained elsewhere
extracted = template.extract_response(generated_text)
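
To make the role formats and the response extraction pattern in Example 3 more concrete, here is a minimal standalone sketch of the idea. It does not reproduce UserAssistantPromptTemplate: the functions format_conversation and extract_last_response are hypothetical, role names are matched case-insensitively by assumption, and the extraction pattern is treated as a literal string, which may differ from how the real class interprets it.

```python
from typing import Dict, List, Optional


def format_conversation(
    messages: List[Dict[str, str]],
    user_format: str,
    assistant_format: str,
    system_format: Optional[str] = None,
    system_message: Optional[str] = None,
) -> str:
    """Render a chat as one prompt string by filling {MESSAGE} in each role format."""
    parts = []
    if system_format is not None and system_message is not None:
        parts.append(system_format.replace("{MESSAGE}", system_message))
    for message in messages:
        role = message["role"].lower()
        if role == "user":
            parts.append(user_format.replace("{MESSAGE}", message["content"]))
        elif role == "assistant":
            parts.append(assistant_format.replace("{MESSAGE}", message["content"]))
        else:
            raise ValueError(f"unsupported role: {message['role']}")
    return "".join(parts)


def extract_last_response(
    generated: str, pattern: str, eos_token: Optional[str] = None
) -> str:
    """Return the text after the last occurrence of pattern, cut at eos_token."""
    _, found, tail = generated.rpartition(pattern)
    if not found:
        raise ValueError("pattern not found in generated text")
    if eos_token and eos_token in tail:
        tail = tail.split(eos_token, 1)[0]
    return tail.strip()
```

With the formats from Example 3, each user turn ends with the assistant header, so the rendered prompt leaves the model positioned to write the next assistant reply, and the same header doubles as the marker after which the reply is extracted.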