Principle:Pytorch Serve Tensor Parallel LLM Architecture
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Model_Architecture |
| Last Updated | 2026-02-13 18:52 GMT |
Overview
Tensor Parallel LLM Architecture is the principle of distributing large language model computation across multiple GPUs by sharding individual weight tensors — using column-wise and row-wise parallelism for attention and feed-forward layers — to enable serving models that exceed single-GPU memory capacity.
Description
As large language models grow to tens or hundreds of billions of parameters, they frequently exceed the memory capacity of a single GPU. Tensor parallelism (TP) addresses this by splitting individual weight matrices across multiple devices, allowing each GPU to hold and compute only a shard of each layer.
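To make the memory argument concrete, the following is a rough back-of-the-envelope sketch. It assumes every weight is sharded evenly across GPUs and ignores activations, the KV cache, and any replicated parameters; the function name `per_gpu_weight_gib` and the 70-billion-parameter, 16-bit example are illustrative assumptions rather than figures from a specific deployment.

```python
def per_gpu_weight_gib(num_params: float, bytes_per_param: int, tp_degree: int) -> float:
    """Approximate weight memory per GPU in GiB, assuming perfectly even sharding."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / tp_degree / 2**30


# Hypothetical model: 70e9 parameters stored in 16-bit precision (2 bytes each).
for tp_degree in (1, 2, 4, 8):
    gib = per_gpu_weight_gib(70e9, 2, tp_degree)
    print(f"TP={tp_degree}: about {gib:.1f} GiB of weights per GPU")
```

Comparing the printed values with the memory of the target GPU gives a first estimate of the smallest tensor-parallel degree that can hold the weights at all.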
The two fundamental sharding strategies are:
- Column-wise parallelism: A weight matrix W of shape (d_in, d_out) is split along the output dimension (columns), so each GPU holds a shard of shape (d_in, d_out / N). Each GPU computes its slice of the output independently, so the matrix multiplication itself requires no communication. The output slices are afterward either concatenated (an all-gather) or passed directly to a following row-parallel layer.
- Row-wise parallelism: A weight matrix W is split along the input dimension (rows), so each GPU holds a shard of shape (d_in / N, d_out). Each GPU processes a shard of the input and produces a partial result. An all-reduce operation then sums the partial results across GPUs.
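The per-GPU shard shapes for the two strategies above follow from dividing one dimension by the number of GPUs. The sketch below is illustrative only: the helper names `column_shard_shape` and `row_shard_shape` and the example dimensions are assumptions, not part of any library.

```python
def column_shard_shape(d_in: int, d_out: int, n_gpus: int) -> tuple:
    """Shape of one shard when W (d_in, d_out) is split along its columns."""
    if d_out % n_gpus != 0:
        raise ValueError("d_out must be divisible by the number of GPUs")
    return (d_in, d_out // n_gpus)


def row_shard_shape(d_in: int, d_out: int, n_gpus: int) -> tuple:
    """Shape of one shard when W (d_in, d_out) is split along its rows."""
    if d_in % n_gpus != 0:
        raise ValueError("d_in must be divisible by the number of GPUs")
    return (d_in // n_gpus, d_out)


# Hypothetical layer dimensions, split across 4 GPUs.
print(column_shard_shape(4096, 11008, 4))  # (4096, 2752)
print(row_shard_shape(4096, 11008, 4))     # (1024, 11008)
```

Both helpers reject dimensions that do not divide evenly, mirroring the divisibility requirement that even sharding imposes.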
In transformer architectures, these strategies are applied systematically:
- Attention layers — The Q, K, V projection matrices are column-parallel (each GPU computes attention for a subset of heads). The output projection is row-parallel.
- Feed-forward layers — The first linear layer is column-parallel (split the intermediate dimension). The second linear layer is row-parallel (reduce back to the model dimension).
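The feed-forward pattern above can be checked on one machine by simulating the GPUs with plain Python lists. This is a minimal sketch under stated assumptions: ReLU stands in for the model's actual activation, a Python loop stands in for parallel devices, a list sum stands in for the all-reduce, and all helper names are hypothetical. It shows why the pairing works: the activation is element-wise, so each simulated GPU can apply it to its own column slice without communication.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def column_shards(w, n):
    step = len(w[0]) // n
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(n)]


def row_shards(w, n):
    step = len(w) // n
    return [w[r * step:(r + 1) * step] for r in range(n)]


def ffn_reference(x, w1, w2):
    """Unsharded feed-forward block: x -> W1 -> ReLU -> W2."""
    return matmul(relu(matmul(x, w1)), w2)


def ffn_tensor_parallel(x, w1, w2, n):
    """Column-parallel W1, row-parallel W2, then a simulated all-reduce."""
    w1_parts, w2_parts = column_shards(w1, n), row_shards(w2, n)
    partials = []
    for rank in range(n):  # each iteration plays the role of one GPU
        hidden = relu(matmul(x, w1_parts[rank]))          # local, no communication
        partials.append(matmul(hidden, w2_parts[rank]))   # partial sum, full width
    # Simulated all-reduce: element-wise sum of the partial outputs.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1.0, -2.0], [0.5, 3.0]]                               # 2 tokens, d_model = 2
w1 = [[1.0, -1.0, 2.0, 0.0], [0.5, 1.0, -1.0, 2.0]]         # (2, 4)
w2 = [[1.0, 0.0], [2.0, -1.0], [0.0, 1.0], [-1.0, 2.0]]     # (4, 2)
assert ffn_tensor_parallel(x, w1, w2, 2) == ffn_reference(x, w1, w2)
```

The values are chosen to be exactly representable in floating point, so the comparison can use plain equality; with arbitrary floats a tolerance would be needed because the summation order differs.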
import torch
import torch.distributed as dist


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column-wise tensor parallelism."""

    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        assert out_features % world_size == 0
        self.rank = rank  # index of the shard owned by this process
        self.local_out = out_features // world_size
        # Local shard of shape (in_features, out_features / world_size)
        self.weight = torch.nn.Parameter(
            torch.randn(in_features, self.local_out)
        )

    def forward(self, x):
        # Each GPU computes its column shard independently
        return x @ self.weight


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row-wise tensor parallelism."""

    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        assert in_features % world_size == 0
        self.rank = rank  # index of the shard owned by this process
        self.local_in = in_features // world_size
        # Local shard of shape (in_features / world_size, out_features)
        self.weight = torch.nn.Parameter(
            torch.randn(self.local_in, out_features)
        )

    def forward(self, x_shard):
        # x_shard holds this rank's slice of the input features
        local_output = x_shard @ self.weight
        # All-reduce to sum partial results across GPUs
        # (requires an initialized torch.distributed process group)
        dist.all_reduce(local_output, op=dist.ReduceOp.SUM)
        return local_output
Usage
Apply Tensor Parallel LLM Architecture when:
- The model size exceeds the memory capacity of a single GPU and must be distributed across multiple devices.
- Low-latency inference is required: tensor parallelism splits the computation of every layer across GPUs and can lower per-request latency, whereas pipeline parallelism still runs the stages of a single request one after another.
- Serving models such as LLaMA 2 70B, GPT-3, or similarly large architectures whose weights do not fit on a single device.
- A checkpoint conversion step is needed to reshape single-device model weights into per-rank shards for distributed loading.
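The checkpoint conversion mentioned in the last item can be sketched as a function that splits each full weight into one shard per rank. This is a simplified illustration, not the conversion tool of any particular framework: weights are plain nested lists in the (d_in, d_out) layout used on this page, and the names `shard_weight`, `convert_checkpoint`, the `plan` mapping, and the parameter names are all hypothetical. Note that `torch.nn.Linear` stores its weight as (out_features, in_features), so the split dimensions would be swapped for real PyTorch checkpoints.

```python
def shard_weight(weight, world_size, dim):
    """Split a 2-D weight into equal shards; dim=0 splits rows, dim=1 splits columns."""
    size = len(weight) if dim == 0 else len(weight[0])
    if size % world_size != 0:
        raise ValueError("dimension must be divisible by world_size")
    step = size // world_size
    if dim == 0:
        return [weight[r * step:(r + 1) * step] for r in range(world_size)]
    return [[row[r * step:(r + 1) * step] for row in weight] for r in range(world_size)]


def convert_checkpoint(state, plan, world_size):
    """Build one state dict per rank.

    `plan` maps a parameter name to the dimension to split; parameters that
    are absent from `plan` are replicated on every rank.
    """
    per_rank = [{} for _ in range(world_size)]
    for name, weight in state.items():
        if name in plan:
            shards = shard_weight(weight, world_size, plan[name])
        else:
            shards = [weight] * world_size
        for rank in range(world_size):
            per_rank[rank][name] = shards[rank]
    return per_rank


# Hypothetical single-device checkpoint with (d_in, d_out) weights.
state = {
    "ffn.up": [[1, 2, 3, 4], [5, 6, 7, 8]],        # column-parallel: split dim 1
    "ffn.down": [[1, 2], [3, 4], [5, 6], [7, 8]],  # row-parallel: split dim 0
    "norm.scale": [[1, 1]],                        # replicated
}
plan = {"ffn.up": 1, "ffn.down": 0}
ranks = convert_checkpoint(state, plan, world_size=2)
print(ranks[0]["ffn.up"], ranks[1]["ffn.down"])
```

Each rank would then load only its own dictionary, so no process ever needs to materialize the full model.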
Theoretical Basis
Tensor parallelism is grounded in the algebraic property that matrix multiplication can be decomposed along either the input or output dimension.
For column-wise parallelism, a matrix multiply Y = X * W where W = [W_1, W_2, ..., W_N] (split along columns) decomposes into:
Y = [X * W_1, X * W_2, ..., X * W_N]
Each partition is computed independently, and the results are concatenated. This requires no inter-GPU communication during computation.
For row-wise parallelism, where W is split along rows and X is correspondingly partitioned as [X_1, X_2, ..., X_N]:
Y = X_1 * W_1 + X_2 * W_2 + ... + X_N * W_N
Each GPU computes a partial sum, and an all-reduce operation produces the final result. The all-reduce exchanges d_out values per token, so the communication cost of each row-parallel layer is O(d_out) per token.
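Both decompositions can be verified numerically. The sketch below uses small random integer matrices so that equality is exact; the function name `check_decompositions` and the chosen sizes are assumptions for illustration, and list slicing stands in for placing shards on separate GPUs.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def check_decompositions(n_gpus, tokens, d_in, d_out, seed=0):
    """Return (Y, Y_col, Y_row): the full product and both sharded reconstructions."""
    rng = random.Random(seed)
    X = [[rng.randint(-3, 3) for _ in range(d_in)] for _ in range(tokens)]
    W = [[rng.randint(-3, 3) for _ in range(d_out)] for _ in range(d_in)]
    Y = matmul(X, W)

    # Column-wise: Y = [X * W_1, ..., X * W_N], concatenated along the columns.
    c = d_out // n_gpus
    col_parts = [matmul(X, [row[r * c:(r + 1) * c] for row in W]) for r in range(n_gpus)]
    Y_col = [sum((part[t] for part in col_parts), []) for t in range(tokens)]

    # Row-wise: Y = X_1 * W_1 + ... + X_N * W_N, summed element-wise.
    k = d_in // n_gpus
    row_parts = [
        matmul([x[r * k:(r + 1) * k] for x in X], W[r * k:(r + 1) * k])
        for r in range(n_gpus)
    ]
    Y_row = [
        [sum(part[t][j] for part in row_parts) for j in range(d_out)]
        for t in range(tokens)
    ]
    return Y, Y_col, Y_row


Y, Y_col, Y_row = check_decompositions(n_gpus=2, tokens=3, d_in=4, d_out=6)
assert Y_col == Y and Y_row == Y
```

The column-wise path needs only a concatenation of independent results, while the row-wise path needs a sum over all shards, which is the step an all-reduce performs on real hardware.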
By alternating column-parallel and row-parallel layers within each transformer block, the architecture requires only two all-reduce operations per transformer layer in the forward pass (one after attention, one after the feed-forward network). This minimizes communication overhead while still distributing the full model across devices.
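The resulting communication volume can be estimated with simple arithmetic. The sketch below counts only the logical payload that is all-reduced per generated token; the actual traffic on the interconnect depends on the all-reduce algorithm and the number of GPUs. The function names and the example model dimensions are illustrative assumptions.

```python
def allreduce_elements_per_token(num_layers: int, d_model: int,
                                 allreduces_per_layer: int = 2) -> int:
    """Elements summed across GPUs per token in one forward pass."""
    return num_layers * allreduces_per_layer * d_model


def allreduce_bytes_per_token(num_layers: int, d_model: int,
                              bytes_per_element: int = 2,
                              allreduces_per_layer: int = 2) -> int:
    """Same quantity in bytes, assuming a fixed element width."""
    return allreduce_elements_per_token(
        num_layers, d_model, allreduces_per_layer
    ) * bytes_per_element


# Hypothetical model: 80 transformer layers, model dimension 8192, 16-bit activations.
elements = allreduce_elements_per_token(80, 8192)
payload = allreduce_bytes_per_token(80, 8192)
print(f"{elements} elements, {payload / 2**20:.1f} MiB all-reduced per token")
```

Because this payload grows with the model dimension and the layer count but not with the tensor-parallel degree, the speed of the GPU interconnect largely determines how well the approach scales.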