Chapter 5: Attention Mechanism

This chapter dives into how attention actually executes during inference. The attention layer is where the KV cache meets the model — it writes new keys and values into paged memory, then reads them back to compute attention scores. We will trace the full path through nano-vllm’s Triton kernel and Flash Attention calls, then map it to vLLM’s pluggable backend system.

Before diving into the code, let's walk through the core attention computation step by step:

Interactive: Self-Attention Computation (5 steps)

Step 1: Query vector (Q)

The current token produces a query vector Q through a linear projection. Q represents "what am I looking for?" — it's the question this token asks of all previous tokens.

The full computation: output = softmax(Q·Kᵀ / √d_k) · V

Demo values: Q (1×4) = [0.80, 0.20, 0.50, 0.10]; K (4×4) rows = [0.90, 0.10, 0.40, 0.20], [0.30, 0.70, 0.10, 0.60], [0.50, 0.50, 0.30, 0.30], [0.10, 0.80, 0.60, 0.40]
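To make the formula concrete, here is the scaled dot-product computed in plain Python with the demo's Q and K values (V is omitted; we stop at the attention weights):

```python
import math

Q = [0.80, 0.20, 0.50, 0.10]
K = [[0.90, 0.10, 0.40, 0.20],
     [0.30, 0.70, 0.10, 0.60],
     [0.50, 0.50, 0.30, 0.30],
     [0.10, 0.80, 0.60, 0.40]]
d_k = len(Q)

# scores[i] = Q . K[i] / sqrt(d_k)
scores = [sum(q * k for q, k in zip(Q, row)) / math.sqrt(d_k) for row in K]

# softmax over the scores yields the attention weights
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]

print([round(w, 3) for w in weights])  # → [0.287, 0.227, 0.249, 0.237]
```

Note how the √d_k scaling keeps the scores in a range where softmax doesn't saturate — without it, larger head dimensions would push the weights toward one-hot.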

KV Cache Write: The Triton Kernel

Before attention can be computed, the newly generated K and V tensors must be stored into the paged KV cache. nano-vllm uses a custom Triton kernel for this:

nano-vllm
nanovllm/layers/attention.py
Attention layer with store_kvcache Triton kernel and prefill/decode dispatch logic.
@triton.jit
def store_kvcache_kernel(
    K, V, KCache, VCache, SlotMapping,
    num_kv_heads: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    slot = tl.load(SlotMapping + token_idx)

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < head_dim

    # K/V are laid out as [num_tokens, num_kv_heads, head_dim]: index by (token, head)
    src = (token_idx * num_kv_heads + head_idx) * head_dim + offsets
    k = tl.load(K + src, mask=mask)
    v = tl.load(V + src, mask=mask)

    # The cache is laid out as [num_slots, num_kv_heads, head_dim]: index by (slot, head)
    dst = (slot * num_kv_heads + head_idx) * head_dim + offsets
    tl.store(KCache + dst, k, mask=mask)
    tl.store(VCache + dst, v, mask=mask)

The kernel is launched with a 2D grid: one program per (token, head) pair. Each program:

  1. Reads the slot_mapping for this token — the slot tells us where in the paged cache this token’s KV should live
  2. Loads the K and V vectors from the model’s output
  3. Stores them into the corresponding cache positions

This is the bridge between the model’s linear projections and the block-managed KV cache from Chapter 4. The slot_mapping is computed by the scheduler based on block allocation, so the kernel doesn’t need to know anything about block tables or page management.
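To make slot_mapping concrete, here is a hypothetical scheduler-side helper (the name, default block size, and layout are illustrative, not nano-vllm's exact code) that derives flat slots from a block table:

```python
def compute_slot_mapping(block_table, context_len, num_new_tokens, block_size=16):
    """Map each new token's logical position to a flat cache slot:
    slot = physical_block_id * block_size + offset_within_block."""
    slots = []
    for pos in range(context_len, context_len + num_new_tokens):
        block_id = block_table[pos // block_size]
        slots.append(block_id * block_size + pos % block_size)
    return slots

# A sequence with 14 tokens of context adds 3 more; its logical blocks
# live in physical blocks 7 and 2:
compute_slot_mapping([7, 2], context_len=14, num_new_tokens=3)
# → [126, 127, 32]: positions 14-15 fill out block 7, position 16 opens block 2
```

The kernel then consumes these flat slot indices directly, which is why it stays oblivious to block tables.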

Prefill Attention: Variable-Length Batching

During prefill, we process entire prompts at once. Different prompts in the batch have different lengths, so we need variable-length attention. nano-vllm uses flash_attn_varlen_func for this:

class Attention(nn.Module):
    def __init__(self, num_heads, head_dim, num_kv_heads, scale, sliding_window=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
        self.scale = scale
        self.sliding_window = sliding_window or (-1, -1)

    def forward(self, q, k, v, kv_cache, attn_metadata):
        # Step 1: Store K/V into paged cache
        store_kvcache(k, v, kv_cache, attn_metadata.slot_mapping)

        if attn_metadata.is_prefill:
            # Step 2a: Prefill — full causal attention over the prompt
            output = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=attn_metadata.cu_seqlens,
                cu_seqlens_k=attn_metadata.cu_seqlens,
                max_seqlen_q=attn_metadata.max_seqlen,
                max_seqlen_k=attn_metadata.max_seqlen,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
            )
        else:
            # Step 2b: Decode — attend to full KV cache
            output = flash_attn_with_kvcache(
                q, kv_cache[0], kv_cache[1],
                cache_seqlens=attn_metadata.seq_lens,
                block_table=attn_metadata.block_table,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
            )
        return output

The Attention.forward() method has two distinct paths:

  • Prefill path: Uses flash_attn_varlen_func with cu_seqlens (cumulative sequence lengths) to handle variable-length sequences in a single batch. Both Q and K/V come from the current input — this is standard self-attention over the full prompt.

  • Decode path: Uses flash_attn_with_kvcache which reads K/V directly from the paged cache via block_table. Only the new token’s Q is provided; the full history comes from cache.

The cu_seqlens tensor is key to variable-length batching. For a batch of prompts with lengths [5, 3, 8], cu_seqlens would be [0, 5, 8, 16]. Flash Attention uses this to know where each sequence starts and ends within the packed tensor.
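Building cu_seqlens from per-sequence lengths is a one-line prefix sum; a minimal sketch:

```python
import torch
import torch.nn.functional as F

seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32)

# cu_seqlens is a cumulative sum with a leading zero; entries i and i+1
# bracket sequence i's rows inside the packed [total_tokens, ...] tensor
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens.tolist())  # → [0, 5, 8, 16]

# Recover sequence 1's slice of a packed tensor:
packed = torch.randn(16, 64)                # 16 total tokens, hidden dim 64
seq1 = packed[cu_seqlens[1]:cu_seqlens[2]]  # rows 5..7 → shape [3, 64]
```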

Decode Attention: Reading from Paged Cache

During decode, the attention kernel must gather K/V from non-contiguous physical blocks. Step through this demo to see how the block table maps logical positions to scattered physical blocks:

Interactive: Paged Attention Decode (7 steps)

Physical block layout

The KV cache is a pool of physical blocks scattered across GPU memory. Our sequence "The cat sat on the mat and then it looked at me" occupies physical blocks 2, 0, and 1 — but they're not in logical order! Block 2 holds the first tokens, block 0 holds the middle, block 1 holds the end.

Block Table: logical → physical
[0] → #2
[1] → #0
[2] → #1

Physical KV Cache Blocks
#0: the | mat | and | then
#1: it | looked | at | me
#2: The | cat | sat | on
#3–#7: (other sequences)

During decode, each request generates one token at a time. The query tensor has shape [batch_size, 1, num_heads, head_dim] — just one position per sequence. But the keys and values span the entire history, stored across potentially non-contiguous cache blocks.

flash_attn_with_kvcache handles this efficiently:

  • block_table maps each sequence to its list of physical block indices
  • cache_seqlens tells the kernel how many tokens each sequence has accumulated
  • The kernel reads K/V from scattered blocks as if they were contiguous

This is where paged attention pays off — sequences can grow without needing contiguous memory, and the attention kernel handles the indirection transparently.
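The indirection can be emulated in a few lines of PyTorch; a sketch (block size, head dimension, and block count are assumptions) of how a kernel logically gathers one sequence's keys via its block table:

```python
import torch

block_size, head_dim, num_blocks = 4, 8, 8
k_cache = torch.randn(num_blocks, block_size, head_dim)  # paged K pool

block_table = torch.tensor([2, 0, 1])  # logical block -> physical block
seq_len = 12                           # tokens accumulated (cache_seqlens)

# Index the pool by the block table, then flatten: the scattered physical
# blocks now read as one contiguous [seq_len, head_dim] key tensor
k_seq = k_cache[block_table].reshape(-1, head_dim)[:seq_len]

# Token 0 lives at offset 0 of physical block 2:
assert torch.equal(k_seq[0], k_cache[2, 0])
# Token 5 lives at offset 1 of physical block 0 (logical block 1):
assert torch.equal(k_seq[5], k_cache[0, 1])
```

A real kernel never materializes k_seq; it streams K/V tiles block by block, but the addressing logic is the same.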

Prefix Cache Attention

When prefix caching is enabled (Chapter 4), a prefill request may have some tokens already cached from a previous request with the same prefix. In this case, only the new tokens need Q computation, but K/V attention spans the full sequence:

cu_seqlens_q = [0, 3]       # only 3 new tokens to compute
cu_seqlens_k = [0, 103]     # but attend to all 103 tokens (100 cached + 3 new)

The key insight: cu_seqlens_q != cu_seqlens_k. The Q side only covers the suffix that wasn’t cached, while the K side covers the entire sequence including the cached prefix. Flash Attention’s variable-length interface supports this asymmetry natively.
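The semantics can be checked with a small pure-PyTorch reference (single head, illustrative shapes): each new query attends to the cached prefix plus earlier new tokens, matching the bottom-right-aligned causal mask Flash Attention applies when the Q and K lengths differ:

```python
import math
import torch

def prefix_attention(q_new, k_full, v_full):
    """Reference for the asymmetric case: q_new covers only the uncached
    suffix, while k_full/v_full cover cached prefix + suffix. New token i
    (absolute position n_cached + i) may attend to every position <= its own."""
    n_new, d = q_new.shape
    n_total = k_full.shape[0]
    n_cached = n_total - n_new
    scores = q_new @ k_full.T / math.sqrt(d)             # [n_new, n_total]
    pos_q = torch.arange(n_new).unsqueeze(1) + n_cached  # absolute query positions
    pos_k = torch.arange(n_total).unsqueeze(0)
    scores = scores.masked_fill(pos_k > pos_q, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v_full

# Sanity check: the result matches the last rows of full causal attention
torch.manual_seed(0)
d, n_total, n_new = 8, 10, 3
q = torch.randn(n_total, d); k = torch.randn(n_total, d); v = torch.randn(n_total, d)
full_scores = (q @ k.T / math.sqrt(d)).masked_fill(
    torch.triu(torch.ones(n_total, n_total, dtype=torch.bool), 1), float("-inf"))
full_out = torch.softmax(full_scores, dim=-1) @ v
assert torch.allclose(prefix_attention(q[-n_new:], k, v), full_out[-n_new:], atol=1e-5)
```

The payoff: the expensive part that scales with query length runs only over the 3 new tokens, while the cached 100 tokens contribute K/V reads but no recomputation.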

Mapping to Production vLLM

Production vLLM abstracts the attention computation behind a pluggable backend system. Instead of directly calling Flash Attention functions, the Attention layer delegates to a backend:

vllm (production)
vllm/model_executor/layers/attention/attention.py
Attention layer — stores KV cache and dispatches to the selected backend implementation.
vllm (production)
vllm/v1/attention/backend.py
Abstract backend interface that all attention implementations must satisfy.

The backend is selected at startup based on hardware and configuration:

vllm (production)
vllm/v1/attention/selector.py
Auto-selects the best attention backend (FlashAttention, FlashInfer, Triton) based on platform and model.

vLLM ships with multiple backend implementations:

vllm (production)
vllm/v1/attention/backends/flash_attn.py
FlashAttention backend — the default on NVIDIA GPUs, wraps flash-attn library calls.
vllm (production)
vllm/v1/attention/backends/flashinfer.py
FlashInfer backend — alternative with optimized paged attention kernels and CUDA graph support.

Key differences from nano-vllm:

  • Pluggable backends — nano-vllm hardcodes Flash Attention calls; vLLM defines an abstract AttentionBackend interface so backends can be swapped without changing model code
  • Backend auto-selection — the selector picks the best backend based on GPU architecture, model type, and user preferences
  • FlashInfer — provides optimized paged attention kernels that can outperform Flash Attention for decode-heavy workloads, especially with CUDA graph capture
  • Triton backend — a pure-Triton fallback for platforms where Flash Attention isn’t available
  • Unified metadata — each backend defines its own AttentionMetadata class to carry the tensors it needs (block tables, sequence lengths, etc.)

The architecture means adding a new attention kernel (say, for a new GPU architecture) only requires implementing the backend interface — no changes to model code or the scheduler.
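As a rough sketch of the pattern (class names and signatures are simplified illustrations, not vLLM's actual API), the contract looks something like:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class AttentionMetadata:
    """Each backend defines its own variant carrying the tensors it needs."""
    slot_mapping: torch.Tensor
    seq_lens: torch.Tensor
    block_table: Optional[torch.Tensor] = None

class AttentionBackend(ABC):
    """Abstract contract: model code calls forward(); backends supply kernels."""
    @abstractmethod
    def forward(self, q, k, v, kv_cache, metadata: AttentionMetadata) -> torch.Tensor:
        ...

class NaiveSDPABackend(AttentionBackend):
    """Illustrative fallback built on torch's scaled_dot_product_attention."""
    def forward(self, q, k, v, kv_cache, metadata):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

def select_backend(platform: str) -> AttentionBackend:
    # Startup-time selection keyed on hardware/config, in the spirit of
    # vLLM's selector (a real table would map to FlashAttention, FlashInfer, etc.)
    backends = {"cpu": NaiveSDPABackend, "cuda": NaiveSDPABackend}
    return backends[platform]()
```

Model code holds only an AttentionBackend reference, so swapping kernels is a one-line change in the selector table.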

Summary

  • The store_kvcache kernel writes K/V into paged cache slots using a Triton kernel indexed by slot_mapping
  • Prefill uses flash_attn_varlen_func with cu_seqlens for variable-length batched self-attention
  • Decode uses flash_attn_with_kvcache to read K/V from the paged block table, computing attention for one new token per sequence
  • Prefix caching exploits cu_seqlens_q != cu_seqlens_k to skip recomputing attention for cached prefixes
  • Production vLLM wraps this behind a pluggable backend system (FlashAttention, FlashInfer, Triton) selected automatically at startup