Chapter 7: Tensor Parallelism
This chapter covers how model weights and computation are split across multiple GPUs. Tensor parallelism (TP) is the primary strategy for serving models that don’t fit on a single GPU — it partitions each layer’s weight matrices so that each GPU computes a slice of the result, then combines them. We will trace the parallel linear layers in nano-vllm, then map them to vLLM’s distributed execution infrastructure.
Column Parallel Linear: Shard by Output Dimension
The simplest form of tensor parallelism splits a weight matrix along its output (column) dimension. If a linear layer maps in_features to out_features and we have N GPUs, each GPU holds a shard covering out_features / N of the outputs. In PyTorch's [out_features, in_features] weight layout, that shard has shape [out_features / N, in_features]:
```python
class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, bias=False):
        super().__init__()
        self.tp_size = get_tensor_parallel_size()
        self.tp_rank = get_tensor_parallel_rank()
        assert output_size % self.tp_size == 0
        self.output_size_per_partition = output_size // self.tp_size
        self.weight = nn.Parameter(
            torch.empty(self.output_size_per_partition, input_size)
        )

    def forward(self, x):
        return F.linear(x, self.weight)
```
Each GPU receives the full input x and multiplies it by its shard of the weight. The output is a slice of the full result — no communication needed at this stage. This is used for projections where we want the output split across GPUs, like the first projection in an MLP.
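The column-parallel identity can be checked with plain NumPy (an illustration, not nano-vllm code): each simulated rank multiplies the full input by its shard, and concatenating the slices reproduces the unsharded matmul.

```python
import numpy as np

# Simulate column parallelism with TP=2: shard the weight along the
# output dimension, compute each rank's slice locally, and verify that
# concatenating the slices equals the full matmul -- no communication.
rng = np.random.default_rng(0)
tp_size = 2
x = rng.standard_normal((4, 8))    # every rank sees the full input [tokens, in]
w = rng.standard_normal((16, 8))   # full weight in [out_features, in_features] layout

shards = np.split(w, tp_size, axis=0)       # shard along the output dimension
slices = [x @ w_s.T for w_s in shards]      # each rank's local matmul
assert np.allclose(np.concatenate(slices, axis=-1), x @ w.T)
```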
Row Parallel Linear: Shard by Input Dimension + All-Reduce
Row parallelism is the complement: the weight is split along the input (row) dimension. Each GPU holds the slice covering in_features / N of the inputs — shape [out_features, in_features / N] in PyTorch's layout — and receives the matching slice of the input:
```python
class RowParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, bias=False):
        super().__init__()
        self.tp_size = get_tensor_parallel_size()
        self.tp_rank = get_tensor_parallel_rank()
        assert input_size % self.tp_size == 0
        self.input_size_per_partition = input_size // self.tp_size
        self.weight = nn.Parameter(
            torch.empty(output_size, self.input_size_per_partition)
        )

    def forward(self, x):
        output = F.linear(x, self.weight)
        # All-reduce: sum partial results across GPUs
        if self.tp_size > 1:
            torch.distributed.all_reduce(output)
        return output
```
The key difference: after the matmul, each GPU has a partial result. The all_reduce sums these partials across all GPUs so every GPU ends up with the correct full output. This is the only communication point in a standard TP layer pair.
The typical pattern in a Transformer is:
- Column parallel for the first projection (e.g., QKV proj, gate+up proj) — splits output, no communication
- Row parallel for the second projection (e.g., output proj, down proj) — recombines with all-reduce
This way, each layer pair requires exactly one all-reduce operation.
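The single-all-reduce property follows from a block-matrix identity, which a small NumPy sketch (not nano-vllm code) can verify: column-sharding the first weight and row-sharding the second lets each rank compute its slice end to end, and summing the per-rank partials — the all-reduce — recovers the unsharded two-layer result.

```python
import numpy as np

# TP=2 simulation of a column-parallel + row-parallel layer pair.
rng = np.random.default_rng(0)
tp_size = 2
x = rng.standard_normal((4, 8))     # [tokens, hidden]
w1 = rng.standard_normal((16, 8))   # first proj:  [intermediate, hidden]
w2 = rng.standard_normal((8, 16))   # second proj: [hidden, intermediate]

# Reference: no parallelism.
ref = (x @ w1.T) @ w2.T

# Column parallel: shard w1 along its output dimension.
w1_shards = np.split(w1, tp_size, axis=0)
# Row parallel: shard w2 along its input dimension to match.
w2_shards = np.split(w2, tp_size, axis=1)

# Each rank computes its slice end-to-end; summing the partials
# (the all-reduce) recovers the full output on every rank.
partials = [(x @ w1_s.T) @ w2_s.T for w1_s, w2_s in zip(w1_shards, w2_shards)]
out = sum(partials)
assert np.allclose(out, ref)
```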
QKV Parallel Linear: Handling GQA Head Counts
Grouped-Query Attention (GQA) uses fewer K/V heads than Q heads. The QKVParallelLinear handles this asymmetry:
```python
class QKVParallelLinear(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_kv_heads):
        super().__init__()
        self.tp_size = get_tensor_parallel_size()
        self.tp_rank = get_tensor_parallel_rank()
        self.num_heads_per_partition = num_heads // self.tp_size
        self.num_kv_heads_per_partition = num_kv_heads // self.tp_size
        # Q gets more columns than K or V
        q_size = self.num_heads_per_partition * head_dim
        kv_size = self.num_kv_heads_per_partition * head_dim
        total_size = q_size + 2 * kv_size
        self.weight = nn.Parameter(
            torch.empty(total_size, hidden_size)
        )

    def forward(self, x):
        return F.linear(x, self.weight)
```
With GQA, if a model has 32 Q heads and 8 KV heads on TP=4, each GPU gets 8 Q heads but only 2 KV heads. The weight is sized accordingly — Q gets a larger slice than K or V. The weight loader must understand this asymmetry when distributing checkpoint weights.
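Working the example through with concrete numbers (head_dim 128 is an assumed value for illustration) shows how asymmetric the fused weight is:

```python
# Per-rank shard sizes for the GQA example: 32 Q heads, 8 KV heads, TP=4.
num_heads, num_kv_heads, head_dim, tp_size = 32, 8, 128, 4
q_heads_per_rank = num_heads // tp_size       # 8 Q heads per GPU
kv_heads_per_rank = num_kv_heads // tp_size   # 2 KV heads per GPU
q_size = q_heads_per_rank * head_dim          # 1024 rows of the fused weight for Q
kv_size = kv_heads_per_rank * head_dim        # 256 rows each for K and V
total = q_size + 2 * kv_size                  # 1536 rows per GPU in total
print(q_size, kv_size, total)                 # 1024 256 1536
```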
Merged Column Parallel Linear: Gate + Up Fusion
The MLP’s gate and up projections are fused into a single MergedColumnParallelLinear:
```python
class MergedColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_sizes, bias=False):
        super().__init__()
        self.tp_size = get_tensor_parallel_size()
        self.tp_rank = get_tensor_parallel_rank()
        self.output_sizes_per_partition = [
            size // self.tp_size for size in output_sizes
        ]
        total = sum(self.output_sizes_per_partition)
        self.weight = nn.Parameter(torch.empty(total, input_size))

    def forward(self, x):
        return F.linear(x, self.weight)
```
This is column parallel applied to multiple logical projections stacked together. Each GPU holds [gate_size/N + up_size/N, hidden_size]. The weight loader stacks the gate and up checkpoint weights into the correct positions within each GPU’s shard.
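Downstream, the MLP splits the fused output back into its logical parts. A NumPy sketch with made-up sizes (not nano-vllm code) shows the split on one simulated rank:

```python
import numpy as np

# One rank's merged gate+up projection at TP=2, with illustrative sizes.
tp_size = 2
hidden, gate_size, up_size = 8, 16, 16
sizes_per_rank = [gate_size // tp_size, up_size // tp_size]   # [8, 8]

x = np.random.default_rng(1).standard_normal((4, hidden))
w = np.random.default_rng(2).standard_normal((sum(sizes_per_rank), hidden))

out = x @ w.T                                   # fused output: [4, 16]
# Split the fused output back into this rank's gate and up slices.
gate, up = np.split(out, [sizes_per_rank[0]], axis=-1)
assert gate.shape == (4, 8) and up.shape == (4, 8)
```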
Vocab Parallel Embedding and LM Head
The vocabulary embedding and output head are also sharded:
```python
class VocabParallelEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.tp_size = get_tensor_parallel_size()
        self.tp_rank = get_tensor_parallel_rank()
        self.vocab_size_per_partition = vocab_size // self.tp_size
        self.vocab_start = self.tp_rank * self.vocab_size_per_partition
        self.vocab_end = self.vocab_start + self.vocab_size_per_partition
        self.weight = nn.Parameter(
            torch.empty(self.vocab_size_per_partition, hidden_size)
        )

    def forward(self, input_ids):
        # Mask out-of-range tokens, embed, then all-reduce
        mask = (input_ids >= self.vocab_start) & (input_ids < self.vocab_end)
        masked_ids = (input_ids - self.vocab_start) * mask
        output = F.embedding(masked_ids, self.weight) * mask.unsqueeze(-1)
        if self.tp_size > 1:
            torch.distributed.all_reduce(output)
        return output
```
Each GPU holds a slice of the vocabulary. For a given token ID, only the GPU whose slice contains that ID produces a non-zero embedding — the rest contribute zeros. The all-reduce sums them so every GPU gets the correct embedding vector.
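The masking trick can be simulated in NumPy (an illustration with a tiny vocabulary, not nano-vllm code): each simulated rank clamps out-of-range IDs to 0, zeroes their embeddings, and the sum over ranks — the all-reduce — matches a lookup into the full table.

```python
import numpy as np

# TP=2 simulation of vocab-parallel embedding over a vocabulary of 8.
vocab, hidden, tp_size = 8, 4, 2
rng = np.random.default_rng(0)
full_table = rng.standard_normal((vocab, hidden))
input_ids = np.array([0, 3, 5, 7])

per_rank = vocab // tp_size
outputs = []
for rank in range(tp_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    shard = full_table[start:end]               # this rank's vocab slice
    mask = (input_ids >= start) & (input_ids < end)
    local_ids = (input_ids - start) * mask      # out-of-range IDs clamp to 0
    outputs.append(shard[local_ids] * mask[:, None])  # zero out wrong-rank rows

# The all-reduce (sum over ranks) recovers the full embedding lookup.
assert np.allclose(sum(outputs), full_table[input_ids])
```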
The ParallelLMHead works similarly but in reverse — each GPU computes logits for its vocabulary slice, and the slices are then gathered so the full logit vector is available for sampling on rank 0:
```python
class ParallelLMHead(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.tp_size = get_tensor_parallel_size()
        self.tp_rank = get_tensor_parallel_rank()
        self.vocab_size_per_partition = vocab_size // self.tp_size
        self.weight = nn.Parameter(
            torch.empty(self.vocab_size_per_partition, hidden_size)
        )

    def forward(self, hidden_states):
        logits = F.linear(hidden_states, self.weight)
        if self.tp_size > 1:
            # All-gather the vocabulary partitions: every rank ends up
            # with the full logits, and sampling then runs on rank 0
            gathered = [torch.empty_like(logits) for _ in range(self.tp_size)]
            torch.distributed.all_gather(gathered, logits)
            logits = torch.cat(gathered, dim=-1)
        return logits
```
Inter-Process Communication
nano-vllm uses shared memory for TP communication between processes on the same node:
```python
# Shared memory setup for TP communication
def _init_shared_memory(self):
    if self.tp_size <= 1:
        return
    # Create shared memory buffers for all-reduce. The name must resolve
    # identically on every rank, so it is derived from the parent PID,
    # which all forked worker processes share (each worker's own PID differs).
    shm_name = f"nano_vllm_tp_{os.getppid()}"
    buf_size = self.max_num_tokens * self.hidden_size * 2  # fp16
    self.shm = shared_memory.SharedMemory(
        name=shm_name, create=(self.tp_rank == 0), size=buf_size
    )
```
The shared memory setup creates memory buffers that all TP worker processes can access directly, avoiding the overhead of socket-based communication for same-node transfers.
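The underlying mechanism is Python's stdlib `multiprocessing.shared_memory`. A minimal sketch (not nano-vllm code; for simplicity both handles live in one process) shows the create-then-attach-by-name pattern:

```python
import uuid
import numpy as np
from multiprocessing import shared_memory

# One side creates a named buffer; the other attaches by the same name.
# Both view the buffer as an array, so writes are visible with no copies.
name = f"tp_demo_{uuid.uuid4().hex[:8]}"  # unique name to avoid collisions
creator = shared_memory.SharedMemory(name=name, create=True, size=16 * 4)
attacher = shared_memory.SharedMemory(name=name)  # attach to existing buffer

src = np.ndarray((16,), dtype=np.float32, buffer=creator.buf)
dst = np.ndarray((16,), dtype=np.float32, buffer=attacher.buf)
src[:] = np.arange(16, dtype=np.float32)
snapshot = dst.copy()
assert np.allclose(snapshot, np.arange(16))  # the attacher sees the writes

# Drop the array views before closing, then release the segment.
del src, dst
attacher.close()
creator.close()
creator.unlink()
```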
Mapping to Production vLLM
Production vLLM has a much more sophisticated distributed execution system. Key differences from nano-vllm:
- NCCL communication — production vLLM uses NCCL for GPU-to-GPU all-reduce, which is significantly faster than shared memory for large tensors. Custom all-reduce kernels further optimize small-message cases.
- Multiple parallelism dimensions — the parallel_state module manages not just TP, but also pipeline parallelism (PP), data parallelism (DP), expert parallelism (EP), and context parallelism (CP). Each has its own process group.
- Multi-node support — the Ray executor distributes workers across machines, while the multiproc executor handles same-node TP with lower overhead.
- Quantization-aware sharding — the parallel linear layers support GPTQ, AWQ, FP8, and other quantization formats, correctly sharding quantized weight matrices and their associated scales.
- Async all-reduce — communication can overlap with computation to hide latency, especially important for pipeline parallelism.
The core idea is identical: column parallel splits output dimensions, row parallel splits input dimensions and all-reduces. Everything else is optimization on top of this pattern.
Summary
- ColumnParallelLinear shards weights by output dimension — each GPU computes a slice, no communication needed
- RowParallelLinear shards by input dimension and uses all-reduce to sum partial results across GPUs
- QKVParallelLinear handles GQA’s asymmetric head counts, giving Q more columns than K/V per GPU
- VocabParallelEmbedding splits the vocabulary across GPUs with masking and all-reduce; ParallelLMHead gathers logits to rank 0
- nano-vllm uses shared memory for same-node TP; production vLLM uses NCCL with custom kernels and supports multi-node distribution via Ray
- The column-parallel + row-parallel pairing ensures each Transformer layer needs exactly one all-reduce operation