Chapter 4: KV Cache & Block Management
PagedAttention is the core innovation that makes vLLM fast. It borrows the idea of virtual memory from operating systems: instead of allocating one contiguous chunk of GPU memory per request, the KV cache is split into fixed-size blocks that can be allocated, freed, and reused independently. This chapter walks through the block management system — from nano-vllm’s 112-line implementation to vLLM’s production block pool with prefix caching.
Why Paged KV Cache?
Without paging, each request pre-allocates a contiguous KV cache buffer for its maximum possible length. A request with max_tokens=2048 reserves memory for 2048 tokens even if it only generates 50. This leads to massive internal fragmentation — most of the reserved memory sits unused.
PagedAttention fixes this by allocating memory in small, fixed-size blocks (typically 16 tokens each). A request that has generated 50 tokens uses only 4 blocks (64 token slots), not 128 blocks (2048 slots). Freed blocks are immediately available to other requests.
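The block arithmetic above is just ceiling division. A minimal sketch, using the 16-token block size assumed throughout this chapter:

```python
# Blocks needed for a given number of tokens (ceiling division):
# a partially filled block still occupies one physical slot.
BLOCK_SIZE = 16

def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    return (num_tokens + block_size - 1) // block_size

print(blocks_needed(50))    # 4 blocks (64 slots), vs. 128 blocks reserved up front
print(blocks_needed(2048))  # 128 blocks, allocated only if actually generated
```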
The analogy to OS virtual memory is direct:
| OS Concept | KV Cache Analog |
|---|---|
| Virtual page | Logical block (position in sequence) |
| Physical page frame | Physical block (GPU memory slot) |
| Page table | Block table (per-sequence mapping) |
| Free frame list | Free block list |
| Page content hash | Block hash (for prefix caching) |
nano-vllm’s BlockManager
nano-vllm’s block manager is compact but covers all the essential concepts:
The Block Class
class Block:
    def __init__(self, block_id: int, block_size: int):
        self.block_id = block_id
        self.block_size = block_size
        self.ref_count = 0
        self.token_ids: List[int] = []

    @property
    def is_full(self):
        return len(self.token_ids) >= self.block_size

    def compute_hash(self):
        return hash(tuple(self.token_ids))
Each block has:
- `block_id` — its index in the physical block pool (used in block tables)
- `block_size` — how many tokens it can hold (fixed, e.g., 16)
- `ref_count` — how many sequences reference this block (enables sharing for prefix caching)
- `token_ids` — the actual tokens stored in this block (used for hash computation)
Allocation and Deallocation
class BlockManager:
    def __init__(self, cache_config):
        self.block_size = cache_config.block_size
        self.num_blocks = cache_config.num_gpu_blocks
        # Keep all blocks indexable by block_id, and a separate free list.
        self.blocks: List[Block] = [
            Block(i, self.block_size) for i in range(self.num_blocks)
        ]
        self.free_blocks: List[Block] = list(self.blocks)
        self.block_hash_map: Dict[int, Block] = {}

    def can_allocate(self, num_blocks: int) -> bool:
        return len(self.free_blocks) >= num_blocks

    def allocate(self, seq: Sequence, num_blocks: int):
        for _ in range(num_blocks):
            block = self.free_blocks.pop()
            block.ref_count = 1
            seq.block_table.append(block.block_id)

    def deallocate(self, seq: Sequence):
        for block_id in seq.block_table:
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                block.token_ids.clear()
                self.free_blocks.append(block)
        seq.block_table.clear()
The free list is a simple Python list used as a stack. Allocation pops blocks off the end; deallocation pushes them back. The ref_count field supports block sharing — when two sequences share a prefix-cached block, its ref_count is 2, and it is only freed when both sequences release it.
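The stack discipline and reference counting can be sketched in a few lines. This is a standalone illustration of the mechanics just described, not the nano-vllm code itself:

```python
# Free list as a stack, with ref counting for shared prefix-cached blocks.
class Block:
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_count = 0

free_blocks = [Block(i) for i in range(4)]  # the free "stack"

# Allocation pops off the end...
shared = free_blocks.pop()
shared.ref_count = 1

# ...and a second sequence sharing the same cached prefix bumps the count.
shared.ref_count += 1
assert shared.ref_count == 2

# The block returns to the free list only when BOTH sequences release it.
shared.ref_count -= 1
shared.ref_count -= 1
if shared.ref_count == 0:
    free_blocks.append(shared)

print(len(free_blocks))  # 4 — all blocks free again
```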
Prefix Cache Lookup
Prefix caching is the idea that two requests with the same prompt prefix can share the same KV cache blocks. nano-vllm implements this with a hash map:
def allocate_with_prefix_cache(self, seq: Sequence, token_ids: List[int]):
    blocks_needed = (len(token_ids) + self.block_size - 1) // self.block_size
    for i in range(blocks_needed):
        start = i * self.block_size
        end = min(start + self.block_size, len(token_ids))
        block_tokens = tuple(token_ids[start:end])
        block_hash = hash(block_tokens)
        if block_hash in self.block_hash_map:
            # Cache hit — reuse existing block
            cached_block = self.block_hash_map[block_hash]
            cached_block.ref_count += 1
            seq.block_table.append(cached_block.block_id)
        else:
            # Cache miss — allocate new block
            block = self.free_blocks.pop()
            block.ref_count = 1
            block.token_ids = list(block_tokens)
            self.block_hash_map[block_hash] = block
            seq.block_table.append(block.block_id)
The hash is computed over the token content of each block. If two requests share the same system prompt (e.g., “You are a helpful assistant…”), the blocks covering that prefix will have identical hashes and be shared. This saves both memory and prefill compute — the shared blocks already have their KV cache computed.
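A small self-contained sketch shows why shared prefixes produce shared blocks (the stand-in token IDs are illustrative, not real tokenizer output):

```python
# Content hashing over fixed-size blocks: two prompts with the same
# 32-token system prefix hash identically on the shared blocks.
BLOCK_SIZE = 16

def block_hashes(token_ids, block_size=BLOCK_SIZE):
    return [hash(tuple(token_ids[i:i + block_size]))
            for i in range(0, len(token_ids), block_size)]

system = list(range(32))        # stand-in tokens for the shared system prompt
prompt_a = system + [100, 101]  # "...What is Python?"
prompt_b = system + [200, 201]  # "...What is Rust?"

ha, hb = block_hashes(prompt_a), block_hashes(prompt_b)
assert ha[0] == hb[0] and ha[1] == hb[1]  # shared prefix blocks hash identically
assert ha[2] != hb[2]                     # final blocks differ
```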
Block Table: Logical to Physical Mapping
The block table is the per-sequence data structure that maps logical positions to physical blocks:
Sequence token_ids: [t0, t1, t2, ..., t47]
Block size: 16
Logical block 0 → Physical block 14 (tokens t0..t15)
Logical block 1 → Physical block 7 (tokens t16..t31)
Logical block 2 → Physical block 22 (tokens t32..t47)
seq.block_table = [14, 7, 22]
During the attention computation, the model needs to read KV cache for all previous tokens. Instead of reading from a contiguous buffer, it uses the block table to gather the correct physical blocks. This is the “paged” part of PagedAttention — the attention kernel follows the block table indirection, just like a CPU follows page table entries.
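The indirection itself is simple arithmetic. A sketch of the logical-to-physical translation, using the example block table above:

```python
# Translate a token's logical position into a physical KV-cache slot,
# the way the attention kernel follows the block table.
BLOCK_SIZE = 16
block_table = [14, 7, 22]  # logical block -> physical block (example above)

def physical_slot(pos: int) -> int:
    logical_block = pos // BLOCK_SIZE   # which entry of the block table
    offset = pos % BLOCK_SIZE           # position within that block
    return block_table[logical_block] * BLOCK_SIZE + offset

print(physical_slot(0))   # 14*16 + 0  = 224
print(physical_slot(20))  # 7*16 + 4   = 116
print(physical_slot(47))  # 22*16 + 15 = 367
```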
Memory Budget: From Warmup to Block Count
How does the engine decide how many blocks to create? The process works like this:
- Model loading — Load the model weights onto the GPU. Note how much memory they consume.
- Warmup profiling — Run a dummy forward pass with a representative batch size. Measure peak GPU memory usage (weights + activations + CUDA overhead).
- Remaining memory — `total_gpu_memory - peak_warmup_memory = available_for_kv_cache`
- Block count — `num_blocks = available_for_kv_cache / (block_size * num_layers * 2 * head_dim * num_heads * dtype_size)`
The factor of 2 accounts for both K and V caches. Each block stores block_size token slots across all layers, for both keys and values.
# Simplified memory budget calculation
import torch

total_memory = torch.cuda.get_device_properties(0).total_memory
model_memory = sum(p.numel() * p.element_size() for p in model.parameters())
activation_memory = profile_peak_memory(model, dummy_batch)  # warmup step
available = total_memory - model_memory - activation_memory
kv_per_block = block_size * num_layers * 2 * num_heads * head_dim * dtype_bytes
num_blocks = int(available * gpu_memory_utilization / kv_per_block)
The gpu_memory_utilization factor (default 0.9 in vLLM) leaves a safety margin to avoid OOM errors.
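Plugging in concrete numbers makes the formula tangible. The configuration below (32 layers, 32 heads, head dim 128, fp16, a 24 GiB GPU with ~14 GiB already used) is an assumed Llama-7B-like example, not measured values:

```python
# Worked numeric instance of the memory budget formula above.
block_size, num_layers, num_heads, head_dim, dtype_bytes = 16, 32, 32, 128, 2

kv_per_block = block_size * num_layers * 2 * num_heads * head_dim * dtype_bytes
print(kv_per_block)  # 8388608 bytes = 8 MiB per block

# Assume 24 GiB total minus ~14 GiB for weights + activations.
available = 24 * 1024**3 - 14 * 1024**3
num_blocks = int(available * 0.9 / kv_per_block)  # 0.9 = gpu_memory_utilization
print(num_blocks)                # 1152 blocks
print(num_blocks * block_size)   # 18432 cacheable token slots
```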
Mapping to Production vLLM
Production vLLM’s block management is substantially more sophisticated, split across several specialized modules.
BlockPool
The BlockPool replaces nano-vllm’s simple free list with a structured pool that supports:
- FreeKVCacheBlockQueue — a doubly-linked list of free blocks, enabling O(1) allocation and deallocation. Allocation pops from the head and freed blocks are appended at the tail, so the head holds the "coldest" (least recently used) blocks, which are reclaimed first when the pool needs to recycle prefix-cached blocks.
- BlockHashToBlockMap — maps block hashes to physical blocks for prefix cache lookup. Unlike nano-vllm’s simple dict, this handles hash collisions and supports efficient eviction of stale entries.
The free queue’s LRU ordering is critical for prefix caching: when a cached block is freed (ref_count drops to 0), it goes to the tail of the free queue rather than being immediately cleared. If a new request needs the same prefix before the block is evicted, it gets a cache hit for free.
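A rough sketch of this behavior, using a `deque` as a stand-in for vLLM's linked-list queue (the real implementation makes the mid-queue removal O(1) via node pointers):

```python
# LRU free queue with lazy prefix-cache reuse: freed cached blocks keep
# their hash entry until they are actually recycled.
from collections import deque

free_queue = deque(["b0", "b1", "b2"])  # head = coldest, tail = hottest
hash_map = {}                            # block hash -> block, kept on free

# A prefix-cached block is freed: append at the tail, keep its hash entry.
free_queue.append("b3")
hash_map["hash(prefix)"] = "b3"

# A new request needs the same prefix before b3 is recycled: cache hit.
block = hash_map["hash(prefix)"]
free_queue.remove(block)  # O(1) in the real linked-list implementation
print("cache hit on", block)

# A request with no cache hit would pop the coldest block from the head.
coldest = free_queue.popleft()
print("recycling", coldest)  # b0
```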
KVCacheManager
The KVCacheManager is the high-level interface that the scheduler calls. It wraps the BlockPool and provides:
- `allocate_slots()` — allocate blocks for a request, checking the prefix cache first
- `free()` — release a request's blocks back to the pool
- `get_computed_blocks()` — find how many of a request's prompt blocks are already cached (prefix cache hits)
- `can_allocate()` — check if enough free blocks exist for a new request
This separation of concerns keeps the scheduler clean — it asks “can I allocate?” and “allocate”, without knowing about hash maps or eviction queues.
KV Cache Data Structures
The building blocks:
- KVCacheBlock — the production equivalent of nano-vllm's `Block`. Carries a `block_id`, `ref_count`, `block_hash`, and linked-list pointers for the free queue.
- BlockHash — a named tuple of `(hash_value, token_ids)`. The hash is computed incrementally: each block's hash incorporates the previous block's hash plus its own tokens, creating a chain that uniquely identifies the full prefix up to that block.
This chained hashing is more robust than nano-vllm’s per-block hash. Two blocks with identical tokens but different preceding contexts will have different hashes, preventing false cache hits.
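The chaining idea can be sketched in a few lines (tiny 2-token blocks keep the example short; this mirrors the scheme described above, not vLLM's exact hash function):

```python
# Chained block hashing: each block's hash folds in the previous block's
# hash, so identical tokens in different contexts hash differently.
BLOCK_SIZE = 2

def chained_hashes(token_ids, block_size=BLOCK_SIZE):
    hashes, prev = [], None
    for i in range(0, len(token_ids), block_size):
        block = tuple(token_ids[i:i + block_size])
        prev = hash((prev, block))  # chain: previous hash + this block's tokens
        hashes.append(prev)
    return hashes

a = chained_hashes([1, 2, 9, 9])  # context [1, 2] before the [9, 9] block
b = chained_hashes([3, 4, 9, 9])  # context [3, 4] before the same tokens
assert a[1] != b[1]               # same tokens, different prefix: no false hit
```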
KVCacheSpec and KVCacheConfig
These configuration classes define the physical layout of the KV cache:
- `KVCacheSpec` — describes the shape of one layer's KV cache: number of heads, head dimension, dtype. Different attention backends (e.g., FlashAttention, PagedAttention) may have different specs.
- `KVCacheConfig` — aggregates specs across all layers and computes the total memory per block, which feeds into the memory budget calculation.
GPU-Side Block Table
On the GPU worker side, the BlockTable class maintains a 2D tensor of shape [max_num_requests, max_num_blocks_per_request] on the GPU. Each row is one request’s block table — the same logical-to-physical mapping we saw in nano-vllm, but stored as a CUDA tensor for direct use by the attention kernel.
The block table is updated incrementally: when the scheduler allocates new blocks for a request, only the new entries are written. This avoids copying the entire table every step.
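A plain-Python sketch of the incremental update (in vLLM the row lives in a CUDA tensor and the write is a narrow slice copy; the helper name and layout here are illustrative):

```python
# Incremental block table update: write only the newly allocated entries,
# never the whole [max_requests x max_blocks] table.
MAX_REQUESTS, MAX_BLOCKS = 4, 8
block_table = [[0] * MAX_BLOCKS for _ in range(MAX_REQUESTS)]
num_blocks = [0] * MAX_REQUESTS  # entries written so far, per request

def append_blocks(req_idx: int, new_block_ids: list[int]) -> None:
    start = num_blocks[req_idx]
    # Only the new slice is written; existing entries are untouched.
    block_table[req_idx][start:start + len(new_block_ids)] = new_block_ids
    num_blocks[req_idx] += len(new_block_ids)

append_blocks(0, [14, 7])  # prefill allocates two blocks
append_blocks(0, [22])     # a later decode step adds one more
print(block_table[0][:3])  # [14, 7, 22]
```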
Prefix Caching: A Worked Example
Consider two chat requests with the same system prompt:
Request A: "You are a helpful assistant.\n\nUser: What is Python?"
Request B: "You are a helpful assistant.\n\nUser: What is Rust?"
Assuming a block size of 16 tokens and the system prompt tokenizes to 32 tokens:
Request A arrives:
Block 0: hash("You are a help") → allocate physical block 5, compute KV
Block 1: hash("ful assistant.\n") → allocate physical block 12, compute KV
Block 2: hash("\nUser: What is") → allocate physical block 8, compute KV
Block 3: hash(" Python?") → allocate physical block 3, compute KV
Request B arrives:
Block 0: hash("You are a help") → HIT! reuse physical block 5, skip KV compute
Block 1: hash("ful assistant.\n") → HIT! reuse physical block 12, skip KV compute
Block 2: hash("\nUser: What is") → HIT! reuse physical block 8, skip KV compute
Block 3: hash(" Rust?") → MISS, allocate physical block 17, compute KV
Request B reuses 3 out of 4 blocks. It only needs to compute KV cache for the final block. This saves both GPU memory (shared blocks are stored once) and compute (shared blocks skip prefill).
In production vLLM, the chained hash ensures that block 2’s hash incorporates blocks 0 and 1’s content, so a block with tokens “\nUser: What is” is only reused if the preceding blocks also match.
The Full Picture
Here is how the block management components connect:
Scheduler
│
▼
KVCacheManager ◄── allocate_slots / free / can_allocate
│
▼
BlockPool ◄── physical block allocation + prefix cache lookup
├── FreeKVCacheBlockQueue (free list with LRU ordering)
└── BlockHashToBlockMap (hash → block for prefix caching)
│
▼
KVCacheBlock[] ◄── the actual block objects (id, ref_count, hash)
│
▼
GPU Block Table ◄── 2D tensor [requests × blocks] on GPU
│
▼
Attention Kernel ◄── reads KV cache via block table indirection
nano-vllm collapses the top four layers into a single BlockManager class. Production vLLM fans them out for performance (O(1) free queue operations, incremental GPU updates) and correctness (chained hashing, LRU eviction).
Summary
- PagedAttention manages KV cache as fixed-size blocks, eliminating the memory fragmentation of contiguous allocation
- The block manager maintains a free list and allocates/deallocates blocks as sequences grow and finish
- Each sequence’s block table maps logical positions to physical block IDs, enabling non-contiguous memory layout
- Prefix caching hashes block contents to detect shared prefixes across requests, reusing both memory and computed KV cache
- The memory budget is determined at startup: total GPU memory minus model weights and activations, divided by per-block KV cache size
- Production vLLM adds a doubly-linked free queue (LRU eviction), chained block hashing (context-aware prefix matching), a dedicated KVCacheManager interface, and GPU-side block table tensors for kernel-level access