Implementation:Turboderp org Exllamav2 TPContext

From Leeroopedia
Knowledge Sources
Domains Tensor Parallelism, Multi-GPU
Last Updated 2026-02-15 00:00 GMT

Overview

TPContext manages tensor-parallel distribution of model weights and activations across multiple GPUs, computing balanced splits for KV heads, intermediate dimensions, vocabulary, residual stream, and query heads.

Description

TPContext is the central coordination object for tensor parallelism in ExLlamaV2. It computes how to partition five different tensor dimensions across available GPUs:

  • BROADCAST_KV (0): Key/value head split for attention layers
  • BROADCAST_ID (1): Intermediate dimension split for MLP layers
  • BROADCAST_VC (2): Vocabulary dimension split for the language model head
  • BROADCAST_RS (3): Residual stream (hidden size) split
  • BROADCAST_Q (4): Query head split (the KV split scaled by num_key_value_groups)
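
The Q split needs no separate balancing: under grouped-query attention each KV head serves num_key_value_groups query heads (num_attention_heads // num_key_value_heads), so every KV range can simply be scaled. A minimal sketch on plain tuples; derive_q_split is a hypothetical helper, not part of ExLlamaV2:

```python
def derive_q_split(kv_split, num_key_value_groups):
    """Scale each (device, start, end) KV-head range to a query-head range.

    Device order and contiguity are preserved because every bound is
    multiplied by the same factor.
    """
    return [(dev, a * num_key_value_groups, b * num_key_value_groups)
            for dev, a, b in kv_split]


# Example: 8 KV heads split 5/3 over two GPUs, 4 query heads per KV head
kv_split = [(0, 0, 5), (1, 5, 8)]
q_split = derive_q_split(kv_split, num_key_value_groups=4)
print(q_split)  # [(0, 0, 20), (1, 20, 32)]
```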

The constructor accepts a gpu_split parameter (a list of floats giving the VRAM budget in GB for each GPU, or None for auto-detection) and calls define_split(), which:

  1. Converts gpu_split to MB and uses integer_split() to partition num_key_value_heads across GPUs in proportion to their budgets.
  2. Subtracts the expected KV cache memory from the available VRAM, based on the cache type (FP16, 8-bit, Q8, Q6, or Q4).
  3. Subtracts the memory footprint of the attention layers proportionally.
  4. Splits the intermediate dimension (in multiples of 128), the residual stream (multiples of 32), and the vocabulary (multiples of 32) across the remaining VRAM.
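
The proportional partitioning in these steps can be sketched as follows. proportional_split is a hypothetical stand-in for integer_split(), whose exact rounding rules are not described on this page; it assumes a largest-remainder scheme, with the weights being the per-GPU VRAM budgets left at each step and chunk being the granularity (1 for KV heads, 128 for the intermediate dimension, 32 for the residual stream and vocabulary).

```python
def proportional_split(total, weights, chunk=1):
    """Split `total` units into len(weights) parts proportional to `weights`.

    Parts are multiples of `chunk`, except that the last part also absorbs
    total % chunk. The largest-remainder method keeps the parts summing to
    exactly `total`.
    """
    n_chunks, leftover = divmod(total, chunk)
    wsum = float(sum(weights))
    ideal = [n_chunks * w / wsum for w in weights]
    counts = [int(x) for x in ideal]
    # Hand out the chunks lost to truncation, largest fractional part first
    missing = max(0, n_chunks - sum(counts))
    order = sorted(range(len(weights)), key=lambda i: ideal[i] - counts[i], reverse=True)
    for i in order[:missing]:
        counts[i] += 1
    parts = [c * chunk for c in counts]
    parts[-1] += leftover
    return parts


print(proportional_split(8, [24.0, 12.0]))                 # [5, 3]
print(proportional_split(14336, [20.0, 20.0], chunk=128))  # [7168, 7168]
```

Rounding to a fixed granularity keeps each device's slice aligned to the sizes the kernels expect, at the cost of a slightly less exact match to the VRAM ratio.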

Each split is stored as a list of (device_index, start, end) tuples. The finalize() method allocates pinned CPU memory for gather operations and creates the native C++ ext_tp_context handle with all split information and CUDA streams.
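
A minimal sketch of how per-device sizes could be turned into contiguous (device_index, start, end) ranges, and how such a range selects a device's shard. The helper name to_split_tuples and the choice to drop devices that receive zero units are assumptions for illustration.

```python
from itertools import accumulate


def to_split_tuples(devices, counts):
    """Turn per-device sizes into contiguous (device_index, start, end) ranges.

    Devices that receive zero units are left out of the split (assumption).
    """
    ends = list(accumulate(counts))
    starts = [0] + ends[:-1]
    return [(dev, a, b) for dev, a, b in zip(devices, starts, ends) if b > a]


split = to_split_tuples([0, 1, 2], [5, 3, 0])
print(split)  # [(0, 0, 5), (1, 5, 8)]

# Each range selects that device's shard along the split dimension
heads = list(range(8))  # stand-in for 8 KV heads
shards = {dev: heads[a:b] for dev, a, b in split}
print(shards)  # {0: [0, 1, 2, 3, 4], 1: [5, 6, 7]}
```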

The class provides broadcast(), gather(), allgather(), and add_residual() operations that use the C++ extension for efficient cross-device data movement. wait_streams() synchronizes all device streams. get_sin_cos() retrieves per-device RoPE sinusoidal tables.
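
The data movement these operations perform can be modelled in a single process. This is a hypothetical reference sketch, not the ExLlamaV2 implementation: "tensors" are lists of rows, devices are implied by position in the split, and the pinned staging buffers, CUDA streams, and C++ extension are all left out.

```python
import copy


def broadcast(x, split):
    """Every device in the split receives its own full copy of x."""
    return [copy.deepcopy(x) for _dev, _a, _b in split]


def gather(partials, split):
    """Concatenate per-device column slices, in split order, into one full tensor."""
    for p, (_dev, a, b) in zip(partials, split):
        assert all(len(row) == b - a for row in p), "slice width must match split"
    rows = len(partials[0])
    return [sum((p[r] for p in partials), []) for r in range(rows)]


def allgather(partials, split_g, split_b):
    """Gather over one split, then broadcast the full result to every device of another."""
    return broadcast(gather(partials, split_g), split_b)


def add_residual(targets, source, split):
    """In place: add the matching column slice [a, b) of source to each device's target."""
    for t, (_dev, a, _b) in zip(targets, split):
        for r, row in enumerate(t):
            for c in range(len(row)):
                row[c] += source[r][a + c]


split = [(0, 0, 2), (1, 2, 4)]   # 4 columns split 2/2 over two devices
partials = [[[1, 2]], [[3, 4]]]  # one row per device-local tensor
print(gather(partials, split))             # [[1, 2, 3, 4]]
print(allgather(partials, split, split))   # [[[1, 2, 3, 4]], [[1, 2, 3, 4]]]
add_residual(partials, [[10, 20, 30, 40]], split)
print(partials)                            # [[[11, 22]], [[33, 44]]]
```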

Scratch buffer allocation is managed through begin_scratch_alloc_tp(), get_scratch_slice_tp(), and get_scratch_slice_tp_bc() for temporary computation buffers on each device.
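
The two slice helpers differ only in how wide each device's slice is. A sketch of that sizing using element counts only; scratch_elements is hypothetical, and the real methods return tensors carved from per-device scratch memory and take a dtype, which this sketch leaves out.

```python
def scratch_elements(rows, split, dim=1, full_width=False):
    """Element count per device for a scratch slice of `rows` rows.

    full_width=False models get_scratch_slice_tp: each device gets
    (end - start) * dim columns, i.e. only its own portion.
    full_width=True models get_scratch_slice_tp_bc: each device gets the
    whole split extent * dim columns.
    """
    total = sum(b - a for _dev, a, b in split)
    return {dev: rows * (total if full_width else b - a) * dim for dev, a, b in split}


id_split = [(0, 0, 7168), (1, 7168, 14336)]
print(scratch_elements(4, id_split))                   # {0: 28672, 1: 28672}
print(scratch_elements(4, id_split, full_width=True))  # {0: 57344, 1: 57344}
```

Multiplying by the element size (2 bytes for FP16) gives the bytes each device must have reserved.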

Usage

TPContext is created by ExLlamaV2 when loading a model with tensor parallelism enabled. Users pass a gpu_split list to the model loader specifying how much VRAM (in GB) each GPU should use, or None for automatic detection. The context is then used internally by attention and MLP layers to split and gather tensors during forward passes.

Code Reference

Source Location

Signature

# Broadcast type constants
BROADCAST_KV = 0   # Key/Value heads
BROADCAST_ID = 1   # Intermediate dimension (MLP)
BROADCAST_VC = 2   # Vocabulary columns
BROADCAST_RS = 3   # Residual stream
BROADCAST_Q = 4    # Query heads

class TPContext:

    model: ExLlamaV2
    kv_split: list[tuple[int, int, int]] | None
    id_split: list[tuple[int, int, int]] | None
    vc_split: list[tuple[int, int, int]] | None
    rs_split: list[tuple[int, int, int]] | None
    q_split: list[tuple[int, int, int]] | None
    pinned_temp: list[torch.Tensor] | None
    device: int | None
    all_devs: list[int | None] | None
    num_devices: int | None
    ext_tp_context: int | None
    sin: list[torch.Tensor | None] | None
    cos: list[torch.Tensor | None] | None

    def __init__(
        self,
        model: ExLlamaV2,
        gpu_split: list[float] | None,
        expect_cache_tokens: int = 0,
        expect_cache_base: type | None = None
    ): ...

Import

from exllamav2.tensor_p import TPContext, BROADCAST_KV, BROADCAST_ID, BROADCAST_VC, BROADCAST_RS, BROADCAST_Q

I/O Contract

Inputs (__init__)

| Name | Type | Required | Description |
|---|---|---|---|
| model | ExLlamaV2 | Yes | The model instance this TP context is associated with |
| gpu_split | list[float] or None | Yes | VRAM budget per GPU in GB; None triggers auto-detection via nvidia-smi/rocm-smi |
| expect_cache_tokens | int | No (default 0) | Expected number of cache tokens; 0 means max_seq_len * max_batch_size |
| expect_cache_base | type or None | No (default None) | Cache class used to estimate the per-element byte cost (FP16, 8-bit, Q8, Q6, or Q4) |
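
How expect_cache_tokens and expect_cache_base feed into the VRAM budget can be sketched with simple arithmetic. The per-element byte costs below are approximations chosen for illustration (quantized caches also store scales, which are ignored here), not the values ExLlamaV2 uses, and the model dimensions are example values.

```python
# Assumed bytes per cache element (illustrative only): FP16 = 2, 8-bit = 1,
# quantized caches approximated by their nominal bit width, scales ignored.
BYTES_PER_ELEMENT = {"fp16": 2.0, "8bit": 1.0, "q8": 1.0, "q6": 0.75, "q4": 0.5}


def expected_tokens(expect_cache_tokens, max_seq_len, max_batch_size):
    """0 means: fall back to max_seq_len * max_batch_size."""
    return expect_cache_tokens or max_seq_len * max_batch_size


def kv_cache_bytes(tokens, kv_heads_on_device, head_dim, num_layers, cache_kind="fp16"):
    """Approximate KV-cache bytes one device must hold for its share of KV heads."""
    elements = 2 * tokens * kv_heads_on_device * head_dim * num_layers  # 2 = keys + values
    return elements * BYTES_PER_ELEMENT[cache_kind]


tokens = expected_tokens(0, max_seq_len=4096, max_batch_size=1)
fp16_mib = kv_cache_bytes(tokens, 4, 128, 32) / 2**20
q4_mib = kv_cache_bytes(tokens, 4, 128, 32, "q4") / 2**20
print(tokens, fp16_mib, q4_mib)  # 4096 256.0 64.0
```

Because the cost scales with the number of KV heads on each device, this subtraction has to happen after the KV split is known and before the remaining dimensions are balanced.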

Inputs (broadcast)

| Name | Type | Required | Description |
|---|---|---|---|
| buffer | int | Yes | Index (0 or 1) of the pinned temporary buffer used for staging |
| input_tensor | torch.Tensor | Yes | The tensor to broadcast from the source device to all devices in the split |
| broadcast_type | int | Yes | One of BROADCAST_KV, BROADCAST_ID, BROADCAST_VC, BROADCAST_RS, BROADCAST_Q |
| dim | int | No (default 1) | Dimension multiplier for the split |

Outputs

| Method | Return type | Description |
|---|---|---|
| broadcast() | list[torch.Tensor] | One tensor per device in the split, each holding the full broadcast data |
| gather() | torch.Tensor | Pinned CPU tensor containing the gathered (concatenated) result from all devices |
| allgather() | list[torch.Tensor] | One full-size tensor per device in the broadcast split, after gather and re-broadcast |
| get_split(broadcast_type) | list[tuple[int, int, int]] | The (device, start, end) split tuples for the given broadcast type |
| get_sin_cos() | tuple[list, list] | Per-device sin and cos RoPE embedding tensors |

Usage Examples

Basic Usage

from exllamav2 import ExLlamaV2, ExLlamaV2Config
from exllamav2.tensor_p import TPContext, BROADCAST_KV

config = ExLlamaV2Config("/path/to/model")
model = ExLlamaV2(config)

# Create TP context: 20 GB on GPU 0, 20 GB on GPU 1
tp_context = TPContext(model, gpu_split=[20.0, 20.0])

# Inspect the KV head split
for dev, a, b in tp_context.kv_split:
    print(f"GPU {dev}: KV heads [{a}:{b}]")

# Finalize (allocates pinned buffers and native handle)
tp_context.finalize()

Broadcasting a Tensor

import torch
from exllamav2.tensor_p import BROADCAST_ID

# tp_context: a finalized TPContext, as created in Basic Usage

# Suppose hidden_states is on GPU 0 after layernorm
hidden_states = torch.randn(1, 128, 4096, device="cuda:0", dtype=torch.half)

# Broadcast to all devices in the intermediate-dimension split
split_tensors = tp_context.broadcast(
    buffer=0,
    input_tensor=hidden_states.view(-1, hidden_states.shape[-1]),
    broadcast_type=BROADCAST_ID
)
# split_tensors is a list with one tensor per GPU

Gathering Results

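from exllamav2.tensor_p import BROADCAST_ID

# tp_context: a finalized TPContext, as created in Basic Usage.
# partial_outputs: a list with one tensor per device in the BROADCAST_ID split,
# each holding that GPU's computed portion of the MLP output.
gathered = tp_context.gather(
    buffer=1,
    inputs=partial_outputs,
    broadcast_type=BROADCAST_ID
)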
# gathered is a pinned CPU tensor with the full concatenated result

Key Methods

| Method | Description |
|---|---|
| __init__(model, gpu_split, ...) | Computes balanced splits for the KV, intermediate, vocabulary, residual, and query dimensions |
| define_split(gpu_split, ...) | Core split computation: partitions each dimension based on the VRAM left after subtracting cache and attention-weight costs |
| finalize() | Allocates pinned temporary buffers and creates the native C++ TP context handle |
| unload() | Frees the native C++ TP context handle |
| all_devices() | Returns a sorted list of all device indices used across the KV, ID, and VC splits |
| get_split(broadcast_type) | Returns the split tuple list for the given broadcast type constant |
| get_devs(broadcast_type) | Returns only the device indices for the given broadcast type |
| broadcast(buffer, input_tensor, broadcast_type, dim) | Copies input_tensor to all devices in the split via the C++ extension |
| gather(buffer, inputs, broadcast_type, dim) | Gathers partial tensors from all devices into a pinned CPU tensor |
| allgather(buffer, inputs, broadcast_type_g, broadcast_type_b, dim) | Gathers from one split and re-broadcasts to another split (all on GPU) |
| add_residual(target, source, broadcast_type, dim) | Adds the corresponding slice of source to each target tensor in the split, using per-device streams |
| wait_streams() | Synchronizes the CUDA streams of all active devices |
| get_sin_cos() | Returns per-device sin and cos RoPE embedding tensors, lazily concatenated from the device contexts |
| get_pinned(buffer, batch_size, q_len, dim) | Returns a view into the pinned temporary buffer reshaped to (batch_size, q_len, dim) |
| begin_scratch_alloc_tp() | Resets scratch allocation on all device contexts |
| get_scratch_slice_tp(rows, dtype, broadcast_type, dim) | Allocates per-device scratch slices sized to each device's portion of the split |
| get_scratch_slice_tp_bc(rows, dtype, broadcast_type, dim) | Allocates per-device scratch slices, each sized to the full broadcast dimension |
| reserve_scratch(scratch) | Reserves scratch space on each device from a list of per-device byte counts |
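
The lookup helpers in this table can be modelled on plain lists. SplitTable is a hypothetical stand-in, not the ExLlamaV2 class; it only reproduces the documented behaviour of get_split(), get_devs(), and all_devices(), reusing the broadcast constant values listed on this page.

```python
BROADCAST_KV, BROADCAST_ID, BROADCAST_VC, BROADCAST_RS, BROADCAST_Q = range(5)


class SplitTable:
    """Hypothetical holder for the five split lists of a TP context."""

    def __init__(self, kv, id_, vc, rs, q):
        self._splits = {BROADCAST_KV: kv, BROADCAST_ID: id_, BROADCAST_VC: vc,
                        BROADCAST_RS: rs, BROADCAST_Q: q}

    def get_split(self, broadcast_type):
        return self._splits[broadcast_type]

    def get_devs(self, broadcast_type):
        return [dev for dev, _a, _b in self.get_split(broadcast_type)]

    def all_devices(self):
        # Sorted union of the devices used by the KV, ID and VC splits
        devs = set()
        for bt in (BROADCAST_KV, BROADCAST_ID, BROADCAST_VC):
            devs.update(self.get_devs(bt))
        return sorted(devs)


t = SplitTable(kv=[(0, 0, 8)],
               id_=[(0, 0, 7168), (2, 7168, 14336)],
               vc=[(1, 0, 32000)],
               rs=[(0, 0, 4096)],
               q=[(0, 0, 32)])
print(t.get_devs(BROADCAST_ID))  # [0, 2]
print(t.all_devices())           # [0, 1, 2]
```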

Broadcast Type Constants

| Constant | Value | Description |
|---|---|---|
| BROADCAST_KV | 0 | Key/value attention head split |
| BROADCAST_ID | 1 | Intermediate (MLP) dimension split |
| BROADCAST_VC | 2 | Vocabulary column split for the LM head |
| BROADCAST_RS | 3 | Residual stream (hidden_size) split |
| BROADCAST_Q | 4 | Query attention head split (KV split * num_key_value_groups) |
