Chapter 5: Attention Mechanism
This chapter dives into how attention actually executes during inference. The attention layer is where the KV cache meets the model — it writes new keys and values into paged memory, then reads them back to compute attention scores. We will trace the full path through nano-vllm’s Triton kernel and Flash Attention calls, then map it to vLLM’s pluggable backend system.
Before diving into code, let's walk through the core attention computation step by step.

Step 1: Query vector (Q)

The current token produces a query vector Q through a linear projection. Q represents "what am I looking for?" — it's the question this token asks of all previous tokens. Each previous token answers through its key K ("what do I contain?") and value V ("what do I contribute?"): the scores softmax(QKᵀ/√d) weight how much of each value flows into this token's output.
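The full computation (query, keys, softmax weights, blended values) fits in a few lines. Here is a dependency-free sketch of single-query scaled dot-product attention, which is exactly what a decode step computes against the cached tokens:

```python
import math

def attention(q, keys, values):
    """Single-query scaled dot-product attention, in plain Python.

    q: vector of length d; keys/values: one vector per cached token.
    """
    d = len(q)
    # Score = similarity of the query to each key, scaled by 1/sqrt(d)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    # Numerically stable softmax turns scores into attention weights
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Output = weight-averaged blend of the value vectors
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(d)]

out = attention(
    q=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]],
    values=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
)
```

A zero query scores every key equally, so the output degenerates to the plain mean of the values — a handy sanity check for any attention implementation.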
KV Cache Write: The Triton Kernel
Before attention can be computed, the newly generated K and V tensors must be stored into the paged KV cache. nano-vllm uses a custom Triton kernel for this:
```python
@triton.jit
def store_kvcache_kernel(
    K, V, KCache, VCache, SlotMapping,
    num_heads: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    token_idx = tl.program_id(0)  # which new token
    head_idx = tl.program_id(1)   # which KV head
    # Where in the paged cache this token's KV should live
    slot = tl.load(SlotMapping + token_idx)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < head_dim
    # K/V are laid out [num_tokens, num_heads, head_dim]
    src = (token_idx * num_heads + head_idx) * head_dim
    k = tl.load(K + src + offsets, mask=mask)
    v = tl.load(V + src + offsets, mask=mask)
    # The cache is laid out [num_slots, num_heads, head_dim]
    dst = (slot * num_heads + head_idx) * head_dim
    tl.store(KCache + dst + offsets, k, mask=mask)
    tl.store(VCache + dst + offsets, v, mask=mask)
```
The kernel is launched with a 2D grid: one program per (token, head) pair. Each program:
- Reads the `slot_mapping` entry for this token — the slot tells us where in the paged cache this token's KV should live
- Loads the K and V vectors from the model's output
- Stores them into the corresponding cache positions
This is the bridge between the model’s linear projections and the block-managed KV cache from Chapter 4. The slot_mapping is computed by the scheduler based on block allocation, so the kernel doesn’t need to know anything about block tables or page management.
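To make the scheduler's side of this concrete, here is a sketch of how a slot mapping can be derived from a sequence's block table. The helper name and layout are illustrative, not nano-vllm's actual code:

```python
def slot_mapping_for(block_table, start_pos, num_tokens, block_size):
    """Map each new token's logical position to a flat cache slot.

    Illustrative sketch: given a sequence's block table (logical block
    index -> physical block index), the slot for logical position p is
    physical_block * block_size + (p % block_size).
    """
    slots = []
    for pos in range(start_pos, start_pos + num_tokens):
        physical_block = block_table[pos // block_size]
        slots.append(physical_block * block_size + pos % block_size)
    return slots

# A sequence whose logical blocks live in physical blocks [7, 2], with
# block_size=4; logical positions 3..5 straddle the block boundary.
slots = slot_mapping_for([7, 2], start_pos=3, num_tokens=3, block_size=4)
# positions 3, 4, 5 -> physical blocks 7, 2, 2 -> slots 31, 8, 9
```

The kernel then consumes this flat list of slots and never sees the block table itself, which is exactly the separation of concerns described above.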
Prefill Attention: Variable-Length Batching
During prefill, we process entire prompts at once. Different prompts in the batch have different lengths, so we need variable-length attention. nano-vllm uses flash_attn_varlen_func for this:
```python
class Attention(nn.Module):
    def __init__(self, num_heads, head_dim, num_kv_heads, scale, sliding_window=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
        self.scale = scale
        self.sliding_window = sliding_window or (-1, -1)

    def forward(self, q, k, v, kv_cache, attn_metadata):
        # Step 1: Store K/V into paged cache
        store_kvcache(k, v, kv_cache, attn_metadata.slot_mapping)
        if attn_metadata.is_prefill:
            # Step 2a: Prefill — full causal attention over the prompt
            output = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=attn_metadata.cu_seqlens,
                cu_seqlens_k=attn_metadata.cu_seqlens,
                max_seqlen_q=attn_metadata.max_seqlen,
                max_seqlen_k=attn_metadata.max_seqlen,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
            )
        else:
            # Step 2b: Decode — attend to full KV cache.
            # flash_attn_with_kvcache expects q as (batch, seqlen_q, heads, dim),
            # so add a seqlen_q=1 axis and strip it afterwards.
            output = flash_attn_with_kvcache(
                q.unsqueeze(1), kv_cache[0], kv_cache[1],
                cache_seqlens=attn_metadata.seq_lens,
                block_table=attn_metadata.block_table,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
            ).squeeze(1)
        return output
```
The Attention.forward() method has two distinct paths:

- Prefill path: uses `flash_attn_varlen_func` with `cu_seqlens` (cumulative sequence lengths) to handle variable-length sequences in a single batch. Both Q and K/V come from the current input — this is standard self-attention over the full prompt.
- Decode path: uses `flash_attn_with_kvcache`, which reads K/V directly from the paged cache via `block_table`. Only the new token's Q is provided; the full history comes from the cache.
The cu_seqlens tensor is key to variable-length batching. For a batch of prompts with lengths [5, 3, 8], cu_seqlens would be [0, 5, 8, 16]. Flash Attention uses this to know where each sequence starts and ends within the packed tensor.
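Building this tensor is just a cumulative sum with a leading zero; a minimal helper (the name is ours, not Flash Attention's) makes the packing explicit:

```python
from itertools import accumulate

def make_cu_seqlens(lengths):
    """Cumulative sequence lengths with a leading zero, in the form
    flash_attn_varlen_func expects for packed variable-length batches."""
    return [0] + list(accumulate(lengths))

cu = make_cu_seqlens([5, 3, 8])
# Sequence i occupies rows cu[i]:cu[i+1] of the packed token tensor,
# so the total packed length is cu[-1] == 16.
```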
Decode Attention: Reading from Paged Cache
During decode, the attention kernel must gather K/V from non-contiguous physical blocks; the block table maps logical token positions to scattered physical blocks.

Physical block layout

The KV cache is a pool of physical blocks scattered across GPU memory. Suppose the sequence "The cat sat on the mat and then it looked at me" occupies physical blocks 2, 0, and 1, in that logical order: block 2 holds the first tokens, block 0 the middle, and block 1 the end. The blocks backing a sequence need not be contiguous or in order.
During decode, each request generates one token at a time. The query tensor has shape [batch_size, 1, num_heads, head_dim] — just one position per sequence. But the keys and values span the entire history, stored across potentially non-contiguous cache blocks.
`flash_attn_with_kvcache` handles this efficiently:

- `block_table` maps each sequence to its list of physical block indices
- `cache_seqlens` tells the kernel how many tokens each sequence has accumulated
- The kernel reads K/V from scattered blocks as if they were contiguous
This is where paged attention pays off — sequences can grow without needing contiguous memory, and the attention kernel handles the indirection transparently.
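Logically, the gather the kernel performs can be sketched in plain Python. This is a toy model of the indirection only; the real kernel fuses the gather with the attention math and never materializes a contiguous copy:

```python
def gather_kv(cache, block_table, seq_len, block_size):
    """What flash_attn_with_kvcache does logically: read one sequence's
    keys (or values) out of scattered physical blocks as if contiguous.

    cache[b][s] is the vector stored at slot s of physical block b.
    """
    out = []
    for pos in range(seq_len):
        physical_block = block_table[pos // block_size]
        out.append(cache[physical_block][pos % block_size])
    return out

# 8 physical blocks of 4 slots each; tag every slot with (block, slot)
# so the gather is easy to check. The sequence's 10 tokens live in
# physical blocks 2, 0, 1 in logical order, as in the example above.
cache = [[(b, s) for s in range(4)] for b in range(8)]
kv = gather_kv(cache, block_table=[2, 0, 1], seq_len=10, block_size=4)
```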
Prefix Cache Attention
When prefix caching is enabled (Chapter 4), a prefill request may have some tokens already cached from a previous request with the same prefix. In this case, only the new tokens need Q computation, but K/V attention spans the full sequence:
```python
cu_seqlens_q = [0, 3]    # only 3 new tokens to compute
cu_seqlens_k = [0, 103]  # but attend to all 103 tokens (100 cached + 3 new)
```
The key insight: cu_seqlens_q != cu_seqlens_k. The Q side only covers the suffix that wasn’t cached, while the K side covers the entire sequence including the cached prefix. Flash Attention’s variable-length interface supports this asymmetry natively.
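To make the asymmetry concrete, here is a toy construction of the causal mask for this case. Flash Attention builds the equivalent mask internally (aligned to the bottom-right of the score matrix); this is only an illustration:

```python
def asymmetric_causal_mask(num_q, num_k):
    """Causal mask when queries cover only the uncached suffix.

    Query row i corresponds to absolute position (num_k - num_q) + i,
    so it may attend to every key at or before that position: the whole
    cached prefix plus the new tokens up to and including itself.
    """
    offset = num_k - num_q
    return [[k_pos <= q_row + offset for k_pos in range(num_k)]
            for q_row in range(num_q)]

# 3 new tokens attending over 103 total (100 cached + 3 new)
mask = asymmetric_causal_mask(num_q=3, num_k=103)
```

The first new token already sees 101 keys (the 100 cached ones plus itself), and the last sees all 103, which is why no attention over the prefix needs recomputing.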
Mapping to Production vLLM
Production vLLM abstracts the attention computation behind a pluggable backend system. Instead of directly calling Flash Attention functions, the Attention layer delegates to a backend, which is selected at startup based on hardware and configuration. vLLM ships with multiple backend implementations, including FlashAttention, FlashInfer, and a pure-Triton fallback.
Key differences from nano-vllm:
- Pluggable backends — nano-vllm hardcodes Flash Attention calls; vLLM defines an abstract `AttentionBackend` interface so backends can be swapped without changing model code
- Backend auto-selection — the selector picks the best backend based on GPU architecture, model type, and user preferences
- FlashInfer — provides optimized paged attention kernels that can outperform Flash Attention for decode-heavy workloads, especially with CUDA graph capture
- Triton backend — a pure-Triton fallback for platforms where Flash Attention isn't available
- Unified metadata — each backend defines its own `AttentionMetadata` class to carry the tensors it needs (block tables, sequence lengths, etc.)
The architecture means adding a new attention kernel (say, for a new GPU architecture) only requires implementing the backend interface — no changes to model code or the scheduler.
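A minimal sketch of what such an interface and selector might look like. The names, fields, and selection policy here are illustrative, not vLLM's exact API, which differs across versions:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

# Illustrative sketch only; vLLM's real AttentionBackend interface and
# metadata classes differ in names and detail.

@dataclass
class AttnMetadata:
    """Per-step tensors a backend needs (placeholder field names)."""
    slot_mapping: object
    block_table: object
    seq_lens: object
    is_prefill: bool

class AttentionBackendSketch(ABC):
    @abstractmethod
    def forward(self, q, k, v, kv_cache, metadata: AttnMetadata):
        """Run attention for one scheduler step."""

def select_backend(has_flashinfer: bool, has_flash_attn: bool) -> str:
    # Startup-time selection, in order of preference (illustrative policy)
    if has_flashinfer:
        return "flashinfer"
    if has_flash_attn:
        return "flash-attn"
    return "triton"
```

The point of the abstraction is visible even in this sketch: model code only ever holds a backend object and a metadata bundle, so a new kernel plugs in by implementing `forward` and registering with the selector.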
Summary
- The `store_kvcache` Triton kernel writes K/V into paged cache slots indexed by `slot_mapping`
- Prefill uses `flash_attn_varlen_func` with `cu_seqlens` for variable-length batched self-attention
- Decode uses `flash_attn_with_kvcache` to read K/V through the paged block table, computing attention for one new token per sequence
- Prefix caching exploits `cu_seqlens_q != cu_seqlens_k` to skip recomputing attention for cached prefixes
- Production vLLM wraps this behind a pluggable backend system (FlashAttention, FlashInfer, Triton) selected automatically at startup