Heuristic:Huggingface Open r1 vLLM GPU Allocation
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Infrastructure, Evaluation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Heuristic for computing the correct number of GPUs for vLLM tensor parallelism based on model attention head count and parameter count.
Description
vLLM enforces a hard constraint that the number of attention heads must be evenly divisible by the number of GPUs used for tensor parallelism, and also that 64 must be divisible by the number of GPUs. Open-R1 implements an automatic GPU count calculator that starts from a maximum GPU count (the num_gpus argument, default 8) and decrements it until both divisibility constraints are met. Models with 30B or more parameters automatically enable tensor parallelism, while smaller models currently default to 2 GPUs as a temporary workaround for limited cluster capacity. When safetensors metadata is unavailable, a fallback heuristic extracts the model's parameter count from the repository ID string.
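The two rules above compose into a single decision. The sketch below is a hypothetical offline version: the function name plan_vllm_gpus is illustrative, not part of Open-R1, and it takes the head count and parameter count as plain integers instead of reading them with AutoConfig or from safetensors metadata.

```python
def plan_vllm_gpus(num_heads: int, param_count: int, max_gpus: int = 8) -> tuple[int, bool]:
    """Return (num_gpus, tensor_parallel) following the heuristic described above.

    Hypothetical helper: num_heads and param_count are passed in directly
    rather than read from the model config or the Hub.
    """
    # Decrement from max_gpus until both divisibility constraints hold.
    # The loop stops at 1 at the latest, because every integer is divisible by 1.
    num_gpus = max_gpus
    while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
        num_gpus -= 1

    # Models with >= 30B parameters are sharded with tensor parallelism.
    if param_count >= 30_000_000_000:
        return num_gpus, True
    # Smaller models, including an unknown count of -1, get the 2-GPU workaround.
    return 2, False


print(plan_vllm_gpus(num_heads=40, param_count=32_000_000_000))  # (8, True)
print(plan_vllm_gpus(num_heads=28, param_count=32_000_000_000))  # (4, True)
print(plan_vllm_gpus(num_heads=28, param_count=7_000_000_000))   # (2, False)
print(plan_vllm_gpus(num_heads=28, param_count=-1))              # (2, False)
```

With 28 heads the loop rejects 8 (28 % 8 != 0), 7 (64 % 7 != 0), 6 and 5 before settling on 4.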
Usage
Use this heuristic when configuring vLLM for model evaluation or any inference task that uses tensor parallelism. Apply it when you need to determine how many GPUs to request for a Slurm evaluation job.
The Insight (Rule of Thumb)
- Action: Use get_gpu_count_for_vllm() to compute the GPU count. Start from 8 GPUs and decrement until num_heads % num_gpus == 0 and 64 % num_gpus == 0.
- Value:
  - Models with >= 30B parameters: enable tensor parallelism with the computed GPU count.
  - Models with < 30B parameters: use 2 GPUs (a temporary hack due to cluster capacity).
  - Possible results when starting from 8: 1, 2, 4, or 8, the only counts up to 8 that divide 64.
- Trade-off: Using fewer GPUs saves cluster resources but may cause out-of-memory (OOM) errors for large models. The 30B threshold is a rough heuristic; some models near this boundary may need manual adjustment.
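The two divisibility checks reduce to one: a GPU count divides both num_heads and 64 exactly when it divides gcd(num_heads, 64). Since 64 = 2^6, that gcd is always a power of two, which is why only 1, 2, 4 and 8 can come out of the loop when it starts at 8. The helper valid_gpu_counts below is a hypothetical sketch for enumerating the candidates, not code from Open-R1.

```python
import math


def valid_gpu_counts(num_heads: int, max_gpus: int = 8) -> list[int]:
    """All GPU counts <= max_gpus that divide both num_heads and 64 (hypothetical helper)."""
    # d divides num_heads and 64 exactly when d divides their greatest common divisor.
    g = math.gcd(num_heads, 64)
    return [d for d in range(1, max_gpus + 1) if g % d == 0]


for heads in (12, 14, 28, 40, 64):
    counts = valid_gpu_counts(heads)
    # The decrement loop returns the largest valid count.
    print(heads, counts, max(counts))
# 12 [1, 2, 4] 4
# 14 [1, 2] 2
# 28 [1, 2, 4] 4
# 40 [1, 2, 4, 8] 8
# 64 [1, 2, 4, 8] 8
```

Lowering max_gpus below 8 never unlocks a non-power-of-two count: with 64 heads and max_gpus=6 the candidates are still [1, 2, 4].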
Reasoning
vLLM tensor parallelism constraint: vLLM splits attention heads evenly across GPUs, so if num_heads is not divisible by num_gpus, the split is impossible and vLLM raises an error. The additional constraint that 64 % num_gpus == 0 comes from vLLM's internal KV cache block size optimization.
30B threshold: Models under 30B parameters generally fit on 2 GPUs with enough headroom for KV cache. Models at or above 30B need more GPUs both for weight sharding and KV cache memory.
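A back-of-envelope weight calculation shows why parameter count drives the split. This sketch assumes 2 bytes per parameter (bf16/fp16 weights) and counts weights only, ignoring KV cache, activations and runtime overhead; the byte width and the helper name weight_gb_per_gpu are assumptions, not values taken from Open-R1.

```python
def weight_gb_per_gpu(param_count: int, tensor_parallel_size: int = 1, bytes_per_param: int = 2) -> float:
    """Approximate weight memory per GPU in GB (1 GB = 1e9 bytes).

    Assumes weights are split evenly across tensor-parallel ranks and ignores
    KV cache, activations and runtime overhead.
    """
    return param_count * bytes_per_param / tensor_parallel_size / 1e9


for params in (7_000_000_000, 30_000_000_000, 70_000_000_000):
    for tp in (1, 2, 4, 8):
        print(f"{params / 1e9:.0f}B params, TP={tp}: {weight_gb_per_gpu(params, tp):.2f} GB/GPU")
```

Under this assumption, a model at the 30B threshold needs about 60 GB for weights alone; sharding across 4 or 8 GPUs cuts that to 15 GB or 7.5 GB per GPU, leaving the rest of each GPU's memory for KV cache.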
Fallback parameter extraction: When safetensors metadata is unavailable (e.g., private models or repositories without safetensors files), the system parses the repo ID for patterns like "7b", "1.5b", "42m", or products such as "8x7b" (MoE models), reading the "b" and "m" suffixes as billions and millions. This regex-based fallback returns the largest parameter count among the matches. If no pattern is found, it returns -1, which will not trigger tensor parallelism.
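The page quotes only the regex of this fallback, not its parsing loop, so the following is a hypothetical sketch of how the described parsing could work rather than the Open-R1 code. It reuses the same pattern, multiplies products such as 8x7b, scales by the b/m suffix, and returns the largest match or -1. The function name param_count_from_name is illustrative.

```python
import re

# Same pattern as quoted from hub.py: a number, an optional "x<number>" product, then a b/m unit.
PATTERN = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"


def param_count_from_name(repo_id: str) -> int:
    """Hypothetical sketch: estimate a parameter count from a repo ID, or return -1."""
    counts = []
    for _full, first, _frac1, _prod, second, _frac2, unit in re.findall(PATTERN, repo_id.lower()):
        # "8x7b" is treated as a product (8 * 7 = 56 billion).
        number = float(first) * float(second) if second else float(first)
        # "b" means billions, "m" means millions.
        number *= 1_000_000_000 if unit == "b" else 1_000_000
        counts.append(number)
    # Keep the largest candidate; -1 means "unknown" and never triggers tensor parallelism.
    return int(max(counts)) if counts else -1


print(param_count_from_name("mistralai/Mixtral-8x7B-v0.1"))  # 56000000000
print(param_count_from_name("openai-community/gpt2"))        # -1
```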
Code Evidence
GPU count computation from src/open_r1/utils/hub.py:121-132:
```python
# Assumed module-level setup (outside the quoted line range).
import logging
from transformers import AutoConfig

logger = logging.getLogger(__name__)


def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int:
    """vLLM enforces a constraint that the number of attention heads must be
    divisible by the number of GPUs and 64 must be divisible by the number of GPUs."""
    config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
    num_heads = config.num_attention_heads
    # Step down from num_gpus until both constraints hold (stops at 1 at worst).
    while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
        logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1}")
        num_gpus -= 1
    return num_gpus
```
30B threshold and 2-GPU hack from src/open_r1/utils/evaluation.py:77-83:
```python
# For large models >= 30b params, we need to shard them across GPUs to avoid OOM
num_gpus = get_gpu_count_for_vllm(model_name, model_revision)
if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
    tensor_parallel = True
else:
    num_gpus = 2  # Hack while cluster is full
    tensor_parallel = False
```
Fallback parameter extraction from src/open_r1/utils/hub.py:89-118:
```python
import re

from huggingface_hub import get_safetensors_metadata


def get_param_count_from_repo_id(repo_id: str) -> int:
    """Get model param counts from safetensors metadata or find patterns like
    42m, 1.5b, 0.5m or products like 8x7b in a repo ID."""
    try:
        metadata = get_safetensors_metadata(repo_id)
        # parameter_count maps each dtype (e.g. "BF16") to its parameter count; the first entry is used.
        return list(metadata.parameter_count.values())[0]
    except Exception:
        pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
        matches = re.findall(pattern, repo_id.lower())
        # ... fallback parsing
```