Implementation: ContextualAI HALOs Sample Main
| Knowledge Sources | |
|---|---|
| Domains | NLP, Inference |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
A command-line tool, provided by the train.sample module, for batched text generation from a language model checkpoint.
Description
The train/sample.py module provides a CLI-driven sampling script that loads a model checkpoint into vLLM's LLM engine with tensor parallelism, iterates through prompt datasets via SFTDataLoader, generates completions with configurable SamplingParams, and writes results to a streaming JSON file. It supports two output modes: alpacaeval mode (adds instruction field for AlpacaEval compatibility) and default mode (includes full multi-turn prompt).
The script validates dataset names against available get_* functions, applies the chat template to prompts, cleans generated text of special tokens, and properly destroys the distributed vLLM environment on completion.
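The special-token cleanup mentioned above can be sketched as follows. This is a hypothetical illustration rather than the script's actual code, and the token set shown is an assumption based on the default '<|im_end|>' stop token; the real list comes from the tokenizer config.

```python
# Hedged sketch of the post-generation cleanup step. The token set below
# is an assumption; the script derives it from the tokenizer, not a constant.
SPECIAL_TOKENS = ("<|im_start|>", "<|im_end|>", "<|endoftext|>")

def clean_completion(text: str) -> str:
    """Strip tokenizer control tokens and surrounding whitespace."""
    for tok in SPECIAL_TOKENS:
        text = text.replace(tok, "")
    return text.strip()

print(clean_completion("A helpful answer.<|im_end|>"))
```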
Usage
For evaluation, run python -m train.sample /path/to/checkpoint --datasets alpacaeval --mode alpacaeval; for online alignment sampling, run python -m train.sample /path/to/checkpoint --datasets ultrachat --num_samples_per_prompt 4.
Code Reference
Source Location
- Repository: ContextualAI/HALOs
- File: train/sample.py
- Lines: L59-133 (main function), L136-156 (argparse)
Signature
def main(args: argparse.Namespace) -> None:
"""Sample completions from a language model using vLLM.
Loads model into vLLM LLM engine, iterates through prompts via SFTDataLoader,
generates completions, and writes results to JSON.
Args (via argparse):
model_path: str - Path to model checkpoint
--datasets: List[str] - Dataset names (default ['alpacaeval'])
--output_file: str - Output JSON path (default 'outputs.json')
--gpu_count: int - Number of GPUs for tensor parallelism (default 1)
--temperature: float - Sampling temperature (default 0.7)
--top_p: float - Nucleus sampling threshold (default 0.95)
--max_tokens: int - Max generation length (default 2048)
--max_prompt_length: int - Max prompt length in tokens (default 512)
--batch_size: int - Prompts per batch (default 1000)
--seed: int - Random seed (default 0)
--split: str - Dataset split (default 'test')
--num_samples_per_prompt: int - Samples per prompt (default 1)
--stop_token: str - Stop token (default '<|im_end|>')
--mode: str - Output mode (default 'alpacaeval')
--num_prompts: int - Total prompts to sample (default None = all)
--num_skip: int - Prompts to skip (default 0)
--num_epochs: int - Passes through data (default 1)
Exit codes:
0: Success (at least one prompt processed)
1: No prompts processed
"""
Import
# Run as CLI module:
# python -m train.sample /path/to/model --datasets alpacaeval
# Or import directly:
from train.sample import main
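When calling main directly, the argparse.Namespace must carry every field the CLI parser would set. A minimal sketch, with field names and defaults taken from the signature above; the model path is a placeholder, not a real checkpoint:

```python
import argparse

# Build the namespace the CLI parser would produce. Defaults mirror the
# documented signature; model_path is a placeholder.
args = argparse.Namespace(
    model_path="/path/to/checkpoint",
    datasets=["alpacaeval"],
    output_file="outputs.json",
    gpu_count=1,
    temperature=0.7,
    top_p=0.95,
    max_tokens=2048,
    max_prompt_length=512,
    batch_size=1000,
    seed=0,
    split="test",
    num_samples_per_prompt=1,
    stop_token="<|im_end|>",
    mode="alpacaeval",
    num_prompts=None,
    num_skip=0,
    num_epochs=1,
)
# main(args)  # not invoked here: requires GPUs and a vLLM install
print(args.mode, args.gpu_count)
```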
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_path | str | Yes | Path to model checkpoint directory |
| datasets | List[str] | No | Dataset names to sample from (default: alpacaeval) |
| gpu_count | int | No | Number of GPUs for tensor parallelism (default: 1) |
| temperature | float | No | Sampling temperature (default: 0.7) |
| top_p | float | No | Top-p nucleus sampling (default: 0.95) |
| max_tokens | int | No | Maximum generated tokens (default: 2048) |
| num_samples_per_prompt | int | No | Completions per prompt (default: 1) |
Outputs
| Name | Type | Description |
|---|---|---|
| output JSON | File | List of dicts with: output, generator, dataset, prompt_id, sample_id, type="sample" |
| instruction field | str | Added in alpacaeval mode for AlpacaEval compatibility |
| prompt field | List[Dict] | Added in default mode as full multi-turn prompt |
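The two output modes can be illustrated with a small record builder. This is a hypothetical sketch of the schema in the table above, not the script's actual code; build_record and its parameters are invented for illustration.

```python
def build_record(output, generator, dataset, prompt_id, sample_id,
                 mode, prompt=None, instruction=None):
    """Assemble one output dict following the documented schema (sketch)."""
    record = {
        "output": output,
        "generator": generator,
        "dataset": dataset,
        "prompt_id": prompt_id,
        "sample_id": sample_id,
        "type": "sample",
    }
    if mode == "alpacaeval":
        # alpacaeval mode: add a top-level instruction string for AlpacaEval
        record["instruction"] = instruction
    else:
        # default mode: keep the full multi-turn prompt (list of chat dicts)
        record["prompt"] = prompt
    return record

rec = build_record("Hi!", "llama3-8B", "alpacaeval", 0, 0,
                   mode="alpacaeval", instruction="Say hi.")
print(sorted(rec))
```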
Usage Examples
AlpacaEval Sampling
python -m train.sample /models/llama3-8B-kto/FINAL \
--datasets alpacaeval \
--mode alpacaeval \
--split test \
--gpu_count 4 \
--output_file alpacaeval_outputs.json
Online Alignment Sampling
python -m train.sample /models/llama3-8B-dpo-round1/FINAL \
--datasets ultrafeedback_armorm \
--mode default \
--split train \
--gpu_count 4 \
--num_samples_per_prompt 4 \
--num_prompts 512 \
--output_file round2_samples.json
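After sampling with --num_samples_per_prompt greater than 1, the output file holds several records per prompt. A quick way to group completions by prompt, using the record fields from the I/O contract above (the JSON shown is made-up example data, not real script output):

```python
import json
from collections import defaultdict

# Made-up JSON mimicking the script's output file; in practice use
# json.load(open("round2_samples.json")) on the real file.
raw = '''[
  {"output": "A", "prompt_id": 0, "sample_id": 0, "type": "sample"},
  {"output": "B", "prompt_id": 0, "sample_id": 1, "type": "sample"},
  {"output": "C", "prompt_id": 1, "sample_id": 0, "type": "sample"}
]'''

# Group the completions for each prompt_id together.
by_prompt = defaultdict(list)
for rec in json.loads(raw):
    by_prompt[rec["prompt_id"]].append(rec["output"])

print(dict(by_prompt))
```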