Implementation:Turboderp org Exllamav2 TPContext
| Knowledge Sources | |
|---|---|
| Domains | Tensor Parallelism, Multi-GPU |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
TPContext manages tensor-parallel distribution of model weights and activations across multiple GPUs, computing balanced splits for KV heads, intermediate dimensions, vocabulary, residual stream, and query heads.
Description
TPContext is the central coordination object for tensor parallelism in ExLlamaV2. It computes how to partition five different tensor dimensions across available GPUs:
- BROADCAST_KV (0): Key/value head split for attention layers
- BROADCAST_ID (1): Intermediate dimension split for MLP layers
- BROADCAST_VC (2): Vocabulary dimension split for the language model head
- BROADCAST_RS (3): Residual stream (hidden size) split
- BROADCAST_Q (4): Query head split (derived from KV split times num_key_value_groups)
The constructor accepts a gpu_split parameter (a list of floats giving the GB available per GPU, or None for auto-detection). It calls define_split(), which:
1. Converts gpu_split to MB and uses integer_split() to partition num_key_value_heads across GPUs proportionally.
2. Subtracts the expected KV cache memory from available VRAM based on the cache type (FP16, 8-bit, Q8, Q6, or Q4).
3. Subtracts the memory footprint of attention layers proportionally.
4. Splits the intermediate dimension (multiples of 128), residual stream (multiples of 32), and vocabulary (multiples of 32) across the remaining VRAM.
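The proportional partitioning described above can be illustrated with a small sketch. The helper `proportional_split` is hypothetical and is not the library's integer_split(); it only shows one plausible way to divide a dimension across devices in proportion to their budgets while keeping every slice a multiple of a granularity (128 for the intermediate dimension, 32 for the residual stream and vocabulary). The example sizes are made up.

```python
def proportional_split(total, budgets, granularity=1):
    """Split `total` units across devices in proportion to `budgets`.

    Hypothetical helper for illustration only. Returns a list of
    (device_index, start, end) tuples; devices that receive nothing
    are omitted.
    """
    if total % granularity != 0:
        raise ValueError("total must be a multiple of granularity")
    if not budgets or any(b < 0 for b in budgets) or sum(budgets) <= 0:
        raise ValueError("budgets must be non-negative with a positive sum")

    blocks = total // granularity
    budget_sum = sum(budgets)

    # Ideal fractional share per device, floored to whole blocks.
    ideal = [blocks * b / budget_sum for b in budgets]
    counts = [int(x) for x in ideal]

    # Give the leftover blocks to the largest fractional remainders.
    leftover = blocks - sum(counts)
    order = sorted(range(len(budgets)),
                   key=lambda i: ideal[i] - counts[i], reverse=True)
    for i in order[:leftover]:
        counts[i] += 1

    # Convert block counts into contiguous (device, start, end) ranges.
    split, start = [], 0
    for dev, n in enumerate(counts):
        if n == 0:
            continue
        end = start + n * granularity
        split.append((dev, start, end))
        start = end
    return split


# 8 KV heads over two equal budgets, and an intermediate dimension of
# 14336 over a 2:1 budget in steps of 128 (example sizes are assumptions).
kv_split = proportional_split(8, [20.0, 20.0])
id_split = proportional_split(14336, [24.0, 12.0], granularity=128)
print(kv_split)
print(id_split)
```

The real define_split() additionally reduces each budget by the cache and attention-weight costs before splitting the remaining dimensions, which this sketch leaves out.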
Each split is stored as a list of (device_index, start, end) tuples. The finalize() method allocates pinned CPU memory for gather operations and creates the native C++ ext_tp_context handle with all split information and CUDA streams.
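As a small illustration of the (device_index, start, end) layout, the sketch below derives a query-head split from a KV-head split by scaling each range by num_key_value_groups, as described for BROADCAST_Q. The function name `derive_q_split` and the example values are assumptions made for this sketch, not part of the library.

```python
def derive_q_split(kv_split, num_key_value_groups):
    """Scale each KV-head range by the number of query heads per KV head.

    Hypothetical helper: the input and output are lists of
    (device_index, start, end) tuples.
    """
    return [
        (dev, a * num_key_value_groups, b * num_key_value_groups)
        for dev, a, b in kv_split
    ]


# Example: 8 KV heads split 5/3 over two GPUs, 4 query heads per KV head.
kv_split = [(0, 0, 5), (1, 5, 8)]
q_split = derive_q_split(kv_split, num_key_value_groups=4)

for dev, a, b in q_split:
    print(f"GPU {dev}: query heads [{a}:{b}]")
```

Because the query ranges are a fixed multiple of the KV ranges, each device holds exactly the query heads that attend to its own KV heads.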
The class provides broadcast(), gather(), allgather(), and add_residual() operations that use the C++ extension for efficient cross-device data movement. wait_streams() synchronizes all device streams. get_sin_cos() retrieves per-device RoPE sinusoidal tables.
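The data movement behind these operations can be modelled without GPUs. The sketch below is a conceptual stand-in that uses plain Python lists in place of tensors and does not call the C++ extension; the function names `emulate_broadcast`, `emulate_gather`, and `emulate_allgather` are hypothetical. It assumes that broadcast gives every device in the split a full copy, that gather concatenates per-device slices in split order, and that `dim` scales the split boundaries.

```python
def emulate_broadcast(row, split):
    """Every device in the split receives a full copy of the input row."""
    return {dev: list(row) for dev, _, _ in split}


def emulate_gather(parts, split, dim=1):
    """Concatenate per-device slices into one full row, in split order.

    `parts` maps device_index -> list holding that device's slice.
    `dim` multiplies the split boundaries (e.g. elements per head).
    """
    total = split[-1][2] * dim
    out = [None] * total
    for dev, a, b in split:
        piece = parts[dev]
        if len(piece) != (b - a) * dim:
            raise ValueError(f"device {dev}: unexpected slice length")
        out[a * dim:b * dim] = piece
    return out


def emulate_allgather(parts, gather_split, broadcast_split, dim=1):
    """Gather according to one split, then re-broadcast to another."""
    full = emulate_gather(parts, gather_split, dim)
    return emulate_broadcast(full, broadcast_split)


# Two devices own 2 and 1 "heads" respectively, with 2 elements per head.
kv_split = [(0, 0, 2), (1, 2, 3)]
parts = {0: [1, 2, 3, 4], 1: [5, 6]}

full_row = emulate_gather(parts, kv_split, dim=2)
copies = emulate_allgather(parts, kv_split, kv_split, dim=2)
print(full_row)
print(copies)
```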
Scratch buffer allocation is managed through begin_scratch_alloc_tp(), get_scratch_slice_tp(), and get_scratch_slice_tp_bc() for temporary computation buffers on each device.
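To make the difference between the two slice helpers concrete, the following sketch computes how many bytes each device would need under each sizing rule. It is only an arithmetic illustration: the function names, the `DTYPE_BYTES` table, and the example split are assumptions, and no memory is actually allocated.

```python
# Bytes per element for the dtypes used in this sketch (assumed names).
DTYPE_BYTES = {"float16": 2, "float32": 4}


def scratch_sizes_split(rows, dtype, split, dim=1):
    """Per-device bytes when each device holds only its own portion."""
    size = DTYPE_BYTES[dtype]
    return {dev: rows * (b - a) * dim * size for dev, a, b in split}


def scratch_sizes_full(rows, dtype, split, dim=1):
    """Per-device bytes when every device holds the full dimension.

    Assumes the split covers the contiguous range [0, end of last slice).
    """
    size = DTYPE_BYTES[dtype]
    full = split[-1][2] * dim
    return {dev: rows * full * size for dev, _, _ in split}


# Example intermediate-dimension split (made-up sizes), 128 rows of FP16.
id_split = [(0, 0, 9600), (1, 9600, 14336)]
partial = scratch_sizes_split(128, "float16", id_split)
full = scratch_sizes_full(128, "float16", id_split)
print(partial)
print(full)
```

Under these assumptions the per-portion sizes sum to the size of a single full-width buffer, while the full-width rule costs that amount on every device.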
Usage
TPContext is created by ExLlamaV2 when loading a model with tensor parallelism enabled. Users pass a gpu_split list to the model loader specifying how much VRAM (in GB) each GPU should use, or None for automatic detection. The context is then used internally by attention and MLP layers to split and gather tensors during forward passes.
Code Reference
Source Location
- Repository: Turboderp_org_Exllamav2
- File: exllamav2/tensor_p.py
- Lines: 21-451
Signature
# Broadcast type constants
BROADCAST_KV = 0 # Key/Value heads
BROADCAST_ID = 1 # Intermediate dimension (MLP)
BROADCAST_VC = 2 # Vocabulary columns
BROADCAST_RS = 3 # Residual stream
BROADCAST_Q = 4 # Query heads
class TPContext:
    model: ExLlamaV2
    kv_split: list[tuple[int, int, int]] | None
    id_split: list[tuple[int, int, int]] | None
    vc_split: list[tuple[int, int, int]] | None
    rs_split: list[tuple[int, int, int]] | None
    q_split: list[tuple[int, int, int]] | None
    pinned_temp: list[torch.Tensor] | None
    device: int | None
    all_devs: list[int | None] | None
    num_devices: int | None
    ext_tp_context: int | None
    sin: list[torch.Tensor | None] | None
    cos: list[torch.Tensor | None] | None

    def __init__(
        self,
        model: ExLlamaV2,
        gpu_split: list[float] | None,
        expect_cache_tokens: int = 0,
        expect_cache_base: type = None
    ): ...
Import
from exllamav2.tensor_p import TPContext, BROADCAST_KV, BROADCAST_ID, BROADCAST_VC, BROADCAST_RS, BROADCAST_Q
I/O Contract
Inputs (__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| model | ExLlamaV2 | Yes | The model instance this TP context is associated with |
| gpu_split | list[float] or None | Yes | VRAM budget per GPU in GB; None triggers auto-detection via nvidia-smi/rocm-smi |
| expect_cache_tokens | int | No (default 0) | Expected number of cache tokens; 0 defaults to max_seq_len * max_batch_size |
| expect_cache_base | type or None | No | Cache class type to estimate per-element byte cost (FP16, 8-bit, Q8, Q6, or Q4) |
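A rough sketch of how expect_cache_tokens and the cache type could feed into a per-device cache estimate is shown below. Everything here is an assumption for illustration: the function `estimate_cache_bytes`, the string keys, and the bits-per-element table are hypothetical, and real quantized caches also store scale data that this sketch ignores. Only the fallback of 0 to max_seq_len * max_batch_size comes from the table above.

```python
# Assumed storage cost in bits per cached element. These are rough lower
# bounds chosen for the example, not values taken from the library.
ASSUMED_BITS_PER_ELEMENT = {"fp16": 16, "8bit": 8, "q8": 8, "q6": 6, "q4": 4}


def estimate_cache_bytes(kv_split, num_layers, head_dim, cache_kind,
                         expect_cache_tokens=0,
                         max_seq_len=4096, max_batch_size=1):
    """Return {device_index: estimated KV cache bytes} for a KV-head split."""
    # A value of 0 falls back to max_seq_len * max_batch_size.
    tokens = expect_cache_tokens or max_seq_len * max_batch_size
    bits = ASSUMED_BITS_PER_ELEMENT[cache_kind]

    per_device = {}
    for dev, a, b in kv_split:
        # Factor 2 covers keys and values for this device's KV heads.
        elements = tokens * num_layers * 2 * (b - a) * head_dim
        per_device[dev] = elements * bits // 8
    return per_device


# Example model shape (made up): 32 layers, 8 KV heads, head_dim 128.
kv_split = [(0, 0, 4), (1, 4, 8)]
fp16 = estimate_cache_bytes(kv_split, num_layers=32, head_dim=128,
                            cache_kind="fp16")
q4 = estimate_cache_bytes(kv_split, num_layers=32, head_dim=128,
                          cache_kind="q4")
print({dev: n / 2**20 for dev, n in fp16.items()})  # MiB per device
print({dev: n / 2**20 for dev, n in q4.items()})
```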
Inputs (broadcast)
| Name | Type | Required | Description |
|---|---|---|---|
| buffer | int | Yes | Index (0 or 1) of the pinned temporary buffer to use for staging |
| input_tensor | torch.Tensor | Yes | The tensor to broadcast from the source device to all split devices |
| broadcast_type | int | Yes | One of BROADCAST_KV, BROADCAST_ID, BROADCAST_VC, BROADCAST_RS, BROADCAST_Q |
| dim | int | No (default 1) | Dimension multiplier for the split |
Outputs
| Name | Type | Description |
|---|---|---|
| broadcast() | list[torch.Tensor] | List of tensors, one per device in the split, containing the full broadcast data |
| gather() | torch.Tensor | Pinned CPU tensor containing the gathered (concatenated) result from all devices |
| allgather() | list[torch.Tensor] | List of full-size tensors, one per device in the broadcast split, after gather + re-broadcast |
| get_split(broadcast_type) | list[tuple[int,int,int]] | The (device, start, end) split tuples for the given broadcast type |
| get_sin_cos() | tuple[list, list] | Per-device sin and cos RoPE embedding tensors |
Usage Examples
Basic Usage
from exllamav2 import ExLlamaV2, ExLlamaV2Config
from exllamav2.tensor_p import TPContext, BROADCAST_KV
config = ExLlamaV2Config("/path/to/model")
model = ExLlamaV2(config)
# Create TP context: 20 GB on GPU 0, 20 GB on GPU 1
tp_context = TPContext(model, gpu_split=[20.0, 20.0])
# Inspect the KV head split
for dev, a, b in tp_context.kv_split:
    print(f"GPU {dev}: KV heads [{a}:{b}]")
# Finalize (allocates pinned buffers and native handle)
tp_context.finalize()
Broadcasting a Tensor
from exllamav2.tensor_p import BROADCAST_ID
import torch
# Suppose hidden_states is on GPU 0 after layernorm
hidden_states = torch.randn(1, 128, 4096, device="cuda:0", dtype=torch.half)
# Broadcast to all devices in the intermediate-dimension split
split_tensors = tp_context.broadcast(
    buffer=0,
    input_tensor=hidden_states.view(-1, hidden_states.shape[-1]),
    broadcast_type=BROADCAST_ID
)
# split_tensors is a list with one tensor per GPU
Gathering Results
# After each GPU computes its portion of the MLP output:
gathered = tp_context.gather(
    buffer=1,
    inputs=partial_outputs,  # list of tensors, one per device
    broadcast_type=BROADCAST_ID
)
# gathered is a pinned CPU tensor with the full concatenated result
Key Methods
| Method | Description |
|---|---|
| __init__(model, gpu_split, ...) | Computes balanced splits for KV, intermediate, vocabulary, residual, and query dimensions |
| define_split(gpu_split, ...) | Core split computation: partitions dimensions based on available VRAM after subtracting cache and attention weight costs |
| finalize() | Allocates pinned temporary buffers and creates the native C++ TP context handle |
| unload() | Frees the native C++ TP context handle |
| all_devices() | Returns a sorted list of all device indices used across KV, ID, and VC splits |
| get_split(broadcast_type) | Returns the split tuple list for the given broadcast type constant |
| get_devs(broadcast_type) | Returns just the device index list for the given broadcast type |
| broadcast(buffer, input_tensor, broadcast_type, dim) | Copies input_tensor to all devices in the split via the C++ extension |
| gather(buffer, inputs, broadcast_type, dim) | Gathers partial tensors from all devices into a pinned CPU tensor |
| allgather(buffer, inputs, broadcast_type_g, broadcast_type_b, dim) | Gather from one split and re-broadcast to another split (all on GPU) |
| add_residual(target, source, broadcast_type, dim) | Adds the corresponding slice of source to each target tensor in the split (per-device streams) |
| wait_streams() | Synchronizes all CUDA streams across all active devices |
| get_sin_cos() | Returns per-device sin and cos RoPE embedding tensors, lazily concatenated from device contexts |
| get_pinned(buffer, batch_size, q_len, dim) | Returns a view into the pinned temporary buffer reshaped to (batch_size, q_len, dim) |
| begin_scratch_alloc_tp() | Resets scratch allocation on all device contexts |
| get_scratch_slice_tp(rows, dtype, broadcast_type, dim) | Allocates per-device scratch slices sized according to each device's portion of the split |
| get_scratch_slice_tp_bc(rows, dtype, broadcast_type, dim) | Allocates per-device scratch slices all sized to the full broadcast dimension |
| reserve_scratch(scratch) | Reserves scratch space on each device from a list of per-device byte counts |
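The lookup-style methods in the table (get_split, get_devs, all_devices) can be pictured with a minimal stand-in. The class `SplitTable` below is hypothetical and holds only plain tuples; it mirrors the behaviour described in the table and is not the TPContext implementation. The constant values match those listed for the broadcast types.

```python
BROADCAST_KV, BROADCAST_ID, BROADCAST_VC, BROADCAST_RS, BROADCAST_Q = range(5)


class SplitTable:
    """Hypothetical stand-in that stores one split list per broadcast type."""

    def __init__(self, kv, id_, vc, rs, q):
        self._splits = {
            BROADCAST_KV: kv,
            BROADCAST_ID: id_,
            BROADCAST_VC: vc,
            BROADCAST_RS: rs,
            BROADCAST_Q: q,
        }

    def get_split(self, broadcast_type):
        """Return the (device, start, end) tuples for a broadcast type."""
        return self._splits[broadcast_type]

    def get_devs(self, broadcast_type):
        """Return only the device indices for a broadcast type."""
        return [dev for dev, _, _ in self.get_split(broadcast_type)]

    def all_devices(self):
        """Sorted device indices used across the KV, ID, and VC splits."""
        used = set()
        for bt in (BROADCAST_KV, BROADCAST_ID, BROADCAST_VC):
            used.update(self.get_devs(bt))
        return sorted(used)


# Made-up splits in which not every device takes part in every dimension.
table = SplitTable(
    kv=[(0, 0, 4), (1, 4, 8)],
    id_=[(0, 0, 4096), (1, 4096, 8192), (2, 8192, 14336)],
    vc=[(1, 0, 16000), (2, 16000, 32000)],
    rs=[(0, 0, 2048), (1, 2048, 4096)],
    q=[(0, 0, 16), (1, 16, 32)],
)
print(table.get_devs(BROADCAST_ID))
print(table.all_devices())
```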
Broadcast Type Constants
| Constant | Value | Description |
|---|---|---|
| BROADCAST_KV | 0 | Key/Value attention head split |
| BROADCAST_ID | 1 | Intermediate (MLP) dimension split |
| BROADCAST_VC | 2 | Vocabulary column split for the LM head |
| BROADCAST_RS | 3 | Residual stream (hidden_size) split |
| BROADCAST_Q | 4 | Query attention head split (KV split * num_key_value_groups) |