

Implementation:ContextualAI HALOs Sample Main

From Leeroopedia


Knowledge Sources
Domains: NLP, Inference
Last Updated: 2026-02-08 03:00 GMT

Overview

Concrete tool, provided by the train.sample module, for batched text generation from a language model checkpoint.
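The batching behavior can be pictured with a small sketch. The helpers below are hypothetical (`prompt_batches` and `run` are not names from train/sample.py). They assume that `--num_skip` and `--num_prompts` apply to the combined prompt stream across all `--num_epochs` passes, which the source does not confirm, and they map the processed-prompt count to the documented exit codes.

```python
from itertools import islice
from typing import Iterator, List, Optional


def prompt_batches(
    prompts: List[str],
    batch_size: int = 1000,
    num_skip: int = 0,
    num_prompts: Optional[int] = None,
    num_epochs: int = 1,
) -> Iterator[List[str]]:
    """Yield lists of at most batch_size prompts (hypothetical helper)."""
    # Assumption: each epoch is one full pass, and skip/limit apply to the combined stream.
    stream = (p for _ in range(num_epochs) for p in prompts)
    stop = None if num_prompts is None else num_skip + num_prompts
    stream = islice(stream, num_skip, stop)
    while True:
        batch = list(islice(stream, batch_size))
        if not batch:
            return
        yield batch


def run(prompts: List[str], **kwargs) -> int:
    """Count processed prompts and return the documented exit code."""
    processed = 0
    for batch in prompt_batches(prompts, **kwargs):
        # The real script would call the vLLM engine on this batch here.
        processed += len(batch)
    # 0 if at least one prompt was processed, 1 otherwise.
    return 0 if processed > 0 else 1


if __name__ == "__main__":
    data = [f"prompt {i}" for i in range(10)]
    print(list(prompt_batches(data, batch_size=4, num_skip=2, num_prompts=5)))
```

Slicing a lazy iterator keeps memory flat even for large datasets; only one batch of prompts is materialized at a time.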

Description

The train/sample.py module provides a CLI-driven sampling script. It loads a model checkpoint into vLLM's LLM engine with tensor parallelism, iterates through prompt datasets via SFTDataLoader, generates completions with configurable SamplingParams, and writes the results to a streaming JSON file. It supports two output modes: alpacaeval mode, which adds an instruction field for AlpacaEval compatibility, and default mode, which includes the full multi-turn prompt.
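A "streaming JSON file" suggests that records are written as they are produced rather than held in memory until the end. A minimal sketch of that idea, assuming the output is a single JSON array; the `StreamingJSONArrayWriter` class is hypothetical and is not the helper HALOs itself uses.

```python
import io
import json
from typing import IO, Any


class StreamingJSONArrayWriter:
    """Write a JSON array one element at a time (hypothetical helper)."""

    def __init__(self, fp: IO[str]) -> None:
        self.fp = fp
        self.first = True
        self.fp.write("[\n")

    def write_item(self, item: Any) -> None:
        # Comma-separate elements so the closed file is a valid JSON array.
        if not self.first:
            self.fp.write(",\n")
        self.fp.write(json.dumps(item))
        self.fp.flush()  # push each record out as soon as it is written
        self.first = False

    def close(self) -> None:
        # The array is only valid JSON after the closing bracket is written.
        self.fp.write("\n]\n")
        self.fp.flush()


if __name__ == "__main__":
    buf = io.StringIO()
    writer = StreamingJSONArrayWriter(buf)
    for i in range(2):
        writer.write_item({"prompt_id": i, "output": f"completion {i}"})
    writer.close()
    print(json.loads(buf.getvalue()))
```

One trade-off of this format: if the job dies before `close()`, the file is an unterminated array and must be repaired before a standard JSON parser will read it.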

The script validates dataset names against the available get_* functions, applies the chat template to prompts, strips special tokens from the generated text, and tears down the distributed vLLM environment when sampling completes.
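Two of these steps, name validation and output cleaning, can be sketched in isolation. Everything here is illustrative: `dataloader` is a hypothetical stand-in for the module that exposes the get_* loaders, and the special-token list is passed in explicitly (in practice it might come from a Hugging Face tokenizer's `all_special_tokens`, which is an assumption about the script).

```python
from types import SimpleNamespace
from typing import Iterable, List

# Hypothetical stand-in for a module that exposes get_<name> loader functions.
dataloader = SimpleNamespace(
    get_alpacaeval=lambda split: [],
    get_ultrachat=lambda split: [],
)


def validate_datasets(names: Iterable[str], module: object = dataloader) -> List[str]:
    """Reject dataset names that have no matching get_<name> function."""
    names = list(names)
    unknown = [n for n in names if not callable(getattr(module, f"get_{n}", None))]
    if unknown:
        raise ValueError(f"Unknown dataset(s): {unknown}")
    return names


def clean_text(text: str, special_tokens: Iterable[str]) -> str:
    """Remove special tokens from a generation and trim surrounding whitespace."""
    for token in special_tokens:
        text = text.replace(token, "")
    return text.strip()


if __name__ == "__main__":
    print(validate_datasets(["alpacaeval", "ultrachat"]))
    print(clean_text("Sure, here you go.<|im_end|>\n", ["<|im_end|>", "<|endoftext|>"]))
```

Validating every name up front means a typo fails immediately instead of after the model has been loaded onto the GPUs.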

Usage

Run as python -m train.sample /path/to/checkpoint --datasets alpacaeval --mode alpacaeval for evaluation, or python -m train.sample /path/to/checkpoint --datasets ultrachat --num_samples_per_prompt 4 for online alignment sampling.

Code Reference

Source Location

  • Repository: ContextualAI/HALOs
  • File: train/sample.py
  • Lines: L59-133 (main function), L136-156 (argparse)

Signature

import argparse

def main(args: argparse.Namespace) -> None:
    """Sample completions from a language model using vLLM.

    Loads model into vLLM LLM engine, iterates through prompts via SFTDataLoader,
    generates completions, and writes results to JSON.

    Args (via argparse):
        model_path: str - Path to model checkpoint
        --datasets: List[str] - Dataset names (default ['alpacaeval'])
        --output_file: str - Output JSON path (default 'outputs.json')
        --gpu_count: int - Number of GPUs for tensor parallelism (default 1)
        --temperature: float - Sampling temperature (default 0.7)
        --top_p: float - Nucleus sampling threshold (default 0.95)
        --max_tokens: int - Max generation length (default 2048)
        --max_prompt_length: int - Max prompt length in tokens (default 512)
        --batch_size: int - Prompts per batch (default 1000)
        --seed: int - Random seed (default 0)
        --split: str - Dataset split (default 'test')
        --num_samples_per_prompt: int - Samples per prompt (default 1)
        --stop_token: str - Stop token (default '<|im_end|>')
        --mode: str - Output mode (default 'alpacaeval')
        --num_prompts: int - Total prompts to sample (default None = all)
        --num_skip: int - Prompts to skip (default 0)
        --num_epochs: int - Passes through data (default 1)

    Exit codes:
        0: Success (at least one prompt processed)
        1: No prompts processed
    """

Import

# Run as CLI module:
# python -m train.sample /path/to/model --datasets alpacaeval

# Or import directly:
from train.sample import main

I/O Contract

Inputs

Name                   | Type      | Required | Description
model_path             | str       | Yes      | Path to model checkpoint directory
datasets               | List[str] | No       | Dataset names to sample from (default: alpacaeval)
gpu_count              | int       | No       | Number of GPUs for tensor parallelism (default: 1)
temperature            | float     | No       | Sampling temperature (default: 0.7)
top_p                  | float     | No       | Top-p nucleus sampling threshold (default: 0.95)
max_tokens             | int       | No       | Maximum number of generated tokens (default: 2048)
num_samples_per_prompt | int       | No       | Completions per prompt (default: 1)
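These inputs line up with vLLM's `LLM` constructor (`tensor_parallel_size` for `gpu_count`) and `SamplingParams` (`temperature`, `top_p`, `max_tokens`, `n`, `seed`, `stop`). The mapping below is an assumption about how the script wires them together; in particular, it requests several completions per prompt through `n`, whereas the script might instead repeat each prompt. It only builds keyword dictionaries, so it runs without vLLM installed.

```python
from typing import Any, Dict, Optional


def engine_kwargs(model_path: str, gpu_count: int = 1) -> Dict[str, Any]:
    """Keyword arguments for vllm.LLM (assumed mapping)."""
    return {"model": model_path, "tensor_parallel_size": gpu_count}


def sampling_kwargs(
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 2048,
    num_samples_per_prompt: int = 1,
    seed: int = 0,
    stop_token: Optional[str] = "<|im_end|>",
) -> Dict[str, Any]:
    """Keyword arguments for vllm.SamplingParams (assumed mapping)."""
    return {
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        # Assumption: several completions per prompt are requested via n.
        "n": num_samples_per_prompt,
        "seed": seed,
        "stop": [stop_token] if stop_token else None,
    }


if __name__ == "__main__":
    # With vLLM installed, these would be used as:
    #   llm = LLM(**engine_kwargs("/path/to/checkpoint", gpu_count=4))
    #   params = SamplingParams(**sampling_kwargs(num_samples_per_prompt=4))
    print(engine_kwargs("/path/to/checkpoint", gpu_count=4))
    print(sampling_kwargs(num_samples_per_prompt=4))
```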

Outputs

Name              | Type       | Description
output            | JSON file  | List of dicts with keys: output, generator, dataset, prompt_id, sample_id, type="sample"
instruction field | str        | Added in alpacaeval mode for AlpacaEval compatibility
prompt field      | List[Dict] | Added in default mode; holds the full multi-turn prompt
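The two output modes differ only in which extra field is attached to each record. A hypothetical `build_record` helper showing that shape, under two assumptions not stated in the source: `generator` holds the model path, and the AlpacaEval `instruction` is the content of the last user turn.

```python
from typing import Any, Dict, List


def build_record(
    mode: str,
    output: str,
    generator: str,
    dataset: str,
    prompt_id: int,
    sample_id: int,
    prompt: List[Dict[str, str]],
) -> Dict[str, Any]:
    """Assemble one output entry in the documented shape (hypothetical helper)."""
    record: Dict[str, Any] = {
        "output": output,
        "generator": generator,  # assumption: the model path identifies the generator
        "dataset": dataset,
        "prompt_id": prompt_id,
        "sample_id": sample_id,
        "type": "sample",
    }
    if mode == "alpacaeval":
        # Assumption: the instruction is the content of the last user turn.
        user_turns = [m["content"] for m in prompt if m["role"] == "user"]
        record["instruction"] = user_turns[-1] if user_turns else ""
    else:
        # Default mode keeps the full multi-turn conversation.
        record["prompt"] = prompt
    return record


if __name__ == "__main__":
    msgs = [{"role": "user", "content": "Name a prime number."}]
    print(build_record("alpacaeval", "7", "/models/llama3-8B-kto/FINAL", "alpacaeval", 0, 0, msgs))
```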

Usage Examples

AlpacaEval Sampling

python -m train.sample /models/llama3-8B-kto/FINAL \
    --datasets alpacaeval \
    --mode alpacaeval \
    --split test \
    --gpu_count 4 \
    --output_file alpacaeval_outputs.json

Online Alignment Sampling

python -m train.sample /models/llama3-8B-dpo-round1/FINAL \
    --datasets ultrafeedback_armorm \
    --mode default \
    --split train \
    --gpu_count 4 \
    --num_samples_per_prompt 4 \
    --num_prompts 512 \
    --output_file round2_samples.json
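With `--num_samples_per_prompt 4`, the output file holds several completions for each prompt as a flat list, and a later step in the online loop would typically consume them per prompt. A hypothetical post-processing helper that regroups records by `prompt_id` and orders them by `sample_id`:

```python
import json
from collections import defaultdict
from typing import Any, Dict, List


def group_by_prompt(records: List[Dict[str, Any]]) -> Dict[Any, List[Dict[str, Any]]]:
    """Group sampled completions by prompt_id, ordered by sample_id."""
    groups: Dict[Any, List[Dict[str, Any]]] = defaultdict(list)
    for record in records:
        groups[record["prompt_id"]].append(record)
    return {pid: sorted(recs, key=lambda r: r["sample_id"]) for pid, recs in groups.items()}


if __name__ == "__main__":
    # Stand-in for json.load(open("round2_samples.json")): two prompts, two samples each.
    text = json.dumps([
        {"prompt_id": 0, "sample_id": 1, "output": "b"},
        {"prompt_id": 0, "sample_id": 0, "output": "a"},
        {"prompt_id": 1, "sample_id": 0, "output": "c"},
        {"prompt_id": 1, "sample_id": 1, "output": "d"},
    ])
    for pid, samples in group_by_prompt(json.loads(text)).items():
        print(pid, [s["output"] for s in samples])
```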

Related Pages

Implements Principle

Requires Environment
