Implementation:Hpcaitech ColossalAI Code Reward Testing Util
| Knowledge Sources | |
|---|---|
| Domains | Code_Evaluation, RLHF, Reward_Computation |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
testing_util.py executes and validates AI-generated code solutions against test cases in a sandboxed environment, supporting both call-based and standard-input evaluation patterns.
Description
This module provides run_test, the core entry point for code reward evaluation. It compiles a generated solution into a RuntimeModule (via pyext), invokes it with the test inputs, and compares the outputs against the expected results using several strategies: direct equality, equality after string stripping, approximate float comparison via numpy.allclose, and set-based comparison. A reliability_guard function disables destructive OS operations (fork, kill, file deletion, subprocess, etc.) to provide a safety sandbox. The CODE_TYPE enum selects between two evaluation modes: call-based (function invocation with JSON-parsed arguments) and standard-input (stdin/stdout patching). The module is adapted from the verl/PRIME projects.
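The layered comparison described above can be illustrated with a minimal sketch. It is not the module's code: outputs_match and _as_list are hypothetical names, the order of the tiers is an assumption, and math.isclose stands in for numpy.allclose so that the example needs only the standard library.

```python
import math


def _as_list(value):
    """Normalize a value to a list of tokens (hypothetical helper)."""
    if isinstance(value, str):
        return value.split()
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def outputs_match(output, expected, rel_tol=1e-9, abs_tol=1e-8):
    """Tiered comparison sketch; tier order is an assumption."""
    # Tier 1: direct equality.
    if output == expected:
        return True
    # Tier 2: equality after stripping surrounding whitespace.
    if isinstance(output, str) and isinstance(expected, str):
        if output.strip() == expected.strip():
            return True
    # Tier 3: approximate numeric comparison
    # (math.isclose used here in place of numpy.allclose).
    try:
        out_nums = [float(x) for x in _as_list(output)]
        exp_nums = [float(x) for x in _as_list(expected)]
        if len(out_nums) == len(exp_nums) and all(
            math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
            for a, b in zip(out_nums, exp_nums)
        ):
            return True
    except (TypeError, ValueError):
        pass  # Not numeric; fall through to the set-based tier.
    # Tier 4: order-insensitive comparison via sets.
    try:
        return set(_as_list(output)) == set(_as_list(expected))
    except TypeError:
        return False  # Unhashable elements cannot be compared as sets.
```

Trying the cheap exact checks first means the tolerant tiers only run when needed, which keeps a looser comparison from masking an exact match.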
Usage
Use this module when computing code-based rewards in RLHF training for code generation tasks. It is invoked by the reward function pipeline to determine whether generated code solutions produce correct outputs, directly feeding into the reward signal for RL training.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py
- Lines: 1-674
Key Functions
def run_test(
    in_outs,
    test=None,
    debug=False,
    timeout=15,
    run_all_tests=False,
) -> Tuple[List, Dict]
def reliability_guard(maximum_memory_bytes=None) -> None
def call_method(method, inputs)
def custom_compare_(output, ground_truth) -> bool
def stripped_string_compare(s1, s2) -> bool
def truncatefn(s, length=300) -> str
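The guarding idea behind reliability_guard, replacing dangerous callables so that later calls fail, can be sketched as follows. This is an illustration only: disable_attributes is a hypothetical helper, and it is demonstrated on a stand-in namespace rather than the real os module so that running it has no side effects. How the real function handles maximum_memory_bytes is not shown here.

```python
import types


def disable_attributes(target, names):
    """Set each named attribute on `target` to None (hypothetical helper).

    Any later attempt to call a disabled attribute raises TypeError,
    because None is not callable.
    """
    for name in names:
        if hasattr(target, name):
            setattr(target, name, None)


# Stand-in for the `os` module; assumption made for a safe demonstration.
fake_os = types.SimpleNamespace(
    kill=lambda pid, sig: None,
    fork=lambda: 0,
    getcwd=lambda: "/",
)

# Disable the destructive entries; "remove" is absent and is skipped.
disable_attributes(fake_os, ["kill", "fork", "remove"])
```

After this call, fake_os.kill(1, 9) raises TypeError while fake_os.getcwd() still works. A guard of this kind is irreversible within the process, so it is usually applied in a worker process rather than in the main training process.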
Key Classes
class CODE_TYPE(Enum):
    call_based = 0
    standard_input = 1

class Capturing(list):
    """Context manager that captures stdout output."""
    def __enter__(self) -> "Capturing"
    def __exit__(self, *args) -> None
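A list-based stdout capture of the kind the Capturing docstring describes is commonly written as below. This is a sketch of the general pattern, not the module's implementation; StdoutCapture is a hypothetical name, and storing the output as separate lines is an assumption.

```python
import sys
from io import StringIO


class StdoutCapture(list):
    """Hypothetical list subclass that collects printed lines."""

    def __enter__(self):
        # Swap sys.stdout for an in-memory buffer.
        self._original_stdout = sys.stdout
        sys.stdout = self._buffer = StringIO()
        return self

    def __exit__(self, *args):
        # Store the captured text line by line, then restore stdout.
        self.extend(self._buffer.getvalue().splitlines())
        sys.stdout = self._original_stdout
        del self._buffer


with StdoutCapture() as captured:
    print("hello")
    print("world")
```

Because the class is itself a list, the captured lines can be indexed, joined, or compared against expected output without any extra accessor.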
Import
from coati.distributed.reward.code_reward.testing_util import run_test, reliability_guard
I/O Contract
Inputs (run_test)
| Name | Type | Required | Description |
|---|---|---|---|
| in_outs | dict | Yes | Dictionary with "inputs", "outputs", and optionally "fn_name" keys defining test cases |
| test | str | Yes | The generated code solution to evaluate (the signature defaults to None, but a solution string must be supplied for evaluation) |
| debug | bool | No | Enable verbose debug output (default: False) |
| timeout | int | No | Timeout in seconds per test case (default: 15) |
| run_all_tests | bool | No | Whether to continue running after first failure (default: False) |
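How an in_outs dictionary might map to an evaluation mode and to call arguments can be sketched as follows. The names EvalMode, select_mode, and parse_call_args are hypothetical, and two behaviors are assumptions inferred from this page rather than confirmed from the source: that a missing or None "fn_name" selects standard-input mode, and that each line of a call-based input is one JSON-encoded argument.

```python
import json
from enum import Enum


class EvalMode(Enum):
    """Hypothetical mirror of the module's CODE_TYPE enum."""

    call_based = 0
    standard_input = 1


def select_mode(in_outs):
    """Pick the mode from the presence of "fn_name" (assumed rule)."""
    if in_outs.get("fn_name") is None:
        return EvalMode.standard_input
    return EvalMode.call_based


def parse_call_args(raw_input):
    """Parse one newline-separated input into positional arguments.

    Assumes each line is a JSON value, e.g. "1\n2" -> [1, 2].
    """
    return [json.loads(line) for line in raw_input.split("\n")]
```

For example, select_mode({"fn_name": "add"}) gives EvalMode.call_based, and parse_call_args("1\n2") gives the argument list [1, 2].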
Outputs (run_test)
| Name | Type | Description |
|---|---|---|
| results | List | List of test results: True (pass), -1 (runtime error/timeout), -2 (compilation error) |
| error_info | Dict | Error metadata dict with "error", "traceback", "output", "expected", "inputs", or "error_message" keys; empty dict if all tests pass |
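A caller that needs a scalar reward has to reduce the results list in some way. The sketch below shows one possible reduction, a pass rate; is_pass and pass_rate are hypothetical helpers and are not part of testing_util.py, and the choice of a fractional rather than all-or-nothing reward is an assumption.

```python
def is_pass(result):
    """True only for a passing entry; -1 and -2 are error codes."""
    return result not in (-1, -2) and bool(result)


def pass_rate(results):
    """Fraction of test cases that passed (0.0 for an empty list)."""
    if not results:
        return 0.0
    passed = sum(1 for r in results if is_pass(r))
    return passed / len(results)
```

With this helper, pass_rate([True, -1]) is 0.5 and pass_rate([-2]) is 0.0. A stricter binary reward would instead require every entry to pass.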
Usage Examples
from coati.distributed.reward.code_reward.testing_util import run_test
# Call-based test
in_outs = {
    "fn_name": "add",
    "inputs": ["1\n2", "3\n4"],
    "outputs": ["3", "7"],
}
test_code = "def add(a, b): return a + b"
results, error_info = run_test(in_outs, test=test_code, timeout=10)
# results: [True, True], error_info: {}
# Standard input test
in_outs_stdin = {
    "fn_name": None,
    "inputs": ["5"],
    "outputs": ["25"],
}
test_code_stdin = "n = int(input())\nprint(n * n)"
results, error_info = run_test(in_outs_stdin, test=test_code_stdin)
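The stdin/stdout patching behind the standard-input example can be approximated with the standard library alone. This is a minimal sketch, not the module's mechanism: run_with_stdin is a hypothetical helper, and it applies no sandboxing and no timeout, so it should only be run on trusted code.

```python
import contextlib
import io
import sys


def run_with_stdin(source, stdin_text):
    """Execute `source` with patched stdin; return what it printed.

    Hypothetical helper: no sandbox and no timeout are applied.
    """
    original_stdin = sys.stdin
    buffer = io.StringIO()
    sys.stdin = io.StringIO(stdin_text)  # input() now reads from this
    try:
        with contextlib.redirect_stdout(buffer):
            exec(source, {"__name__": "__main__"})
    finally:
        sys.stdin = original_stdin  # always restore the real stdin
    return buffer.getvalue()


printed = run_with_stdin("n = int(input())\nprint(n * n)", "5")
```

The returned text still carries the trailing newline written by print, so it has to be stripped before it can equal an expected string such as "25".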