Chapter 8: Model Runner & CUDA Graphs

The Model Runner is where scheduling decisions become GPU computation. It takes the scheduler’s output, builds the right tensor inputs for prefill or decode, runs the model forward pass, and — critically — uses CUDA Graphs to eliminate kernel launch overhead during decoding. This chapter walks through how nano-vllm implements this, then maps it to the production vLLM codebase.

The ModelRunner Class

In nano-vllm, ModelRunner owns the entire execution pipeline: KV cache allocation, input preparation, model execution, and CUDA Graph capture.

nano-vllm, nanovllm/engine/model_runner.py: ModelRunner handles KV cache allocation, input preparation, model execution, and CUDA Graph optimization.

The class is initialized with the model, tokenizer, and cache configuration. Before any real inference happens, it goes through a warmup phase to profile GPU memory usage.

Warmup and Memory Profiling

Before serving requests, the engine needs to know how much GPU memory is available for KV cache blocks. The approach is straightforward: run a dummy forward pass, measure peak memory usage, then allocate the remaining memory to KV cache.

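def warmup_model(self):
    """Run a worst-case dummy forward pass to profile peak GPU memory usage."""
    torch.cuda.empty_cache()  # start profiling from a clean allocator state
    # Activation memory grows with the number of tokens in a step, so size the
    # dummy batch to the largest step the engine will run.
    num_tokens = self.max_num_batched_tokens  # assumed config attribute
    input_ids = torch.zeros(num_tokens, dtype=torch.long, device=self.device)
    positions = torch.arange(num_tokens, dtype=torch.long, device=self.device)
    # No KV cache exists yet; attention is assumed to skip cache writes when given None.
    self.model(input_ids, positions, None, is_prefill=True)
    torch.cuda.synchronize()

This warmup_model() call makes PyTorch allocate the buffers a real forward pass needs: activations, attention workspaces, and other temporaries (inference allocates no weight gradients). The dummy batch is sized to the largest step because a one-token pass would badly underestimate peak activation memory. After the pass, PyTorch's caching allocator keeps the freed blocks reserved instead of returning them to the driver, so torch.cuda.mem_get_info() reports free memory that already excludes this working set.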

KV Cache Allocation

With the memory profile in hand, allocate_kv_cache() computes how many blocks fit in the remaining GPU memory:

def allocate_kv_cache(self):
    free_memory, _ = torch.cuda.mem_get_info()
    # Reserve some memory for runtime overhead
    usable = int(free_memory * 0.9)

    # Each block stores [block_size] tokens for all layers
    # Shape per layer: [num_blocks, block_size, num_kv_heads, head_dim]
    block_memory = (
        self.num_layers * 2 *  # 2 for K and V
        self.block_size *
        self.num_kv_heads *
        self.head_dim *
        self.dtype_size
    )
    num_blocks = usable // block_memory

    self.kv_caches = []
    for _ in range(self.num_layers):
        k_cache = torch.zeros(num_blocks, self.block_size, self.num_kv_heads, self.head_dim, ...)
        v_cache = torch.zeros(num_blocks, self.block_size, self.num_kv_heads, self.head_dim, ...)
        self.kv_caches.append((k_cache, v_cache))

This is a one-time allocation. The block manager (covered in Chapter 5) then hands out blocks from this pool as sequences arrive.
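To make the sizing arithmetic concrete, here is the same computation as two plain functions. The model dimensions (32 layers, 8 KV heads, head_dim 128, fp16, block_size 16) and the 20 GiB of free memory are hypothetical numbers chosen for illustration, not measurements of any particular model or GPU.

```python
def kv_block_bytes(num_layers: int, block_size: int, num_kv_heads: int,
                   head_dim: int, dtype_size: int) -> int:
    """Bytes for one KV cache block across all layers (the 2 covers K and V)."""
    return num_layers * 2 * block_size * num_kv_heads * head_dim * dtype_size


def num_kv_blocks(free_bytes: int, block_bytes: int, utilization: float = 0.9) -> int:
    """Whole blocks that fit after holding back a fraction for runtime overhead."""
    usable = int(free_bytes * utilization)
    return usable // block_bytes


# Hypothetical model: 32 layers, 8 KV heads, head_dim 128, fp16 (2 bytes), 16-token blocks
block_bytes = kv_block_bytes(num_layers=32, block_size=16, num_kv_heads=8,
                             head_dim=128, dtype_size=2)
print(block_bytes)                  # prints: 2097152 (2 MiB per block)
blocks = num_kv_blocks(20 * 1024**3, block_bytes)
print(blocks, blocks * 16)          # prints: 9216 147456 (blocks, cacheable tokens)
```

Dividing block_bytes by block_size gives the per-token footprint, here 128 KiB, and that per-token cost rather than the raw GPU size is what decides how many tokens the engine can keep cached at once.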

Input Preparation: Prefill vs Decode

The most important distinction in the model runner is how it builds inputs for prefill versus decode. These two phases have fundamentally different tensor shapes.

Prefill Inputs

During prefill, we process the entire prompt at once. Multiple sequences may be batched together, but each has a variable-length prompt. The prepare_prefill() method builds the inputs:

def prepare_prefill(self, scheduler_output):
    input_ids = []
    positions = []
    slot_mapping = []
    cu_seqlens_q = [0]
    cu_seqlens_k = [0]

    for seq in scheduler_output.prefill_sequences:
        tokens = seq.get_token_ids()
        seq_len = len(tokens)

        input_ids.extend(tokens)
        positions.extend(range(seq_len))

        # Map each token position to its KV cache slot
        for i in range(seq_len):
            block_idx = seq.block_table[i // self.block_size]
            block_offset = i % self.block_size
            slot_mapping.append(block_idx * self.block_size + block_offset)

        cu_seqlens_q.append(cu_seqlens_q[-1] + seq_len)
        cu_seqlens_k.append(cu_seqlens_k[-1] + seq_len)

    return ModelInput(
        input_ids=torch.tensor(input_ids, device=self.device),
        positions=torch.tensor(positions, device=self.device),
        cu_seqlens_q=torch.tensor(cu_seqlens_q, dtype=torch.int32, device=self.device),
        cu_seqlens_k=torch.tensor(cu_seqlens_k, dtype=torch.int32, device=self.device),
        slot_mapping=torch.tensor(slot_mapping, device=self.device),
        is_prefill=True,
    )

Key details:

  • Flattened input_ids — all sequences are concatenated into a single 1D tensor. No padding needed.
  • cu_seqlens_q / cu_seqlens_k — cumulative sequence lengths, used by FlashAttention to know where each sequence starts and ends in the flattened tensor.
  • slot_mapping — maps each token position to its physical slot in the KV cache. This is how the attention kernel knows where to write the K/V values.
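Stripped of tensors, the prefill bookkeeping is just list arithmetic. The sketch below reproduces the steps of prepare_prefill() on plain Python lists; the function name build_prefill_inputs and the two toy sequences with their block tables are hypothetical. Because nothing is cached yet, cu_seqlens_q and cu_seqlens_k are identical in the code above, so the sketch builds only one list.

```python
from typing import Dict, List


def build_prefill_inputs(prompts: List[List[int]], block_tables: List[List[int]],
                         block_size: int) -> Dict[str, List[int]]:
    """Flatten variable-length prompts the way prepare_prefill() does."""
    input_ids: List[int] = []
    positions: List[int] = []
    slot_mapping: List[int] = []
    cu_seqlens = [0]  # prefix sums of sequence lengths
    for tokens, table in zip(prompts, block_tables):
        input_ids.extend(tokens)
        positions.extend(range(len(tokens)))
        for i in range(len(tokens)):
            # Logical position i lives in physical block table[i // block_size]
            slot_mapping.append(table[i // block_size] * block_size + i % block_size)
        cu_seqlens.append(cu_seqlens[-1] + len(tokens))
    return {"input_ids": input_ids, "positions": positions,
            "slot_mapping": slot_mapping, "cu_seqlens": cu_seqlens}


# Two prompts (3 and 5 tokens) with block_size=4; the second spans two blocks.
out = build_prefill_inputs(
    prompts=[[11, 12, 13], [21, 22, 23, 24, 25]],
    block_tables=[[7], [2, 5]],
    block_size=4,
)
print(out["cu_seqlens"])    # [0, 3, 8]
print(out["slot_mapping"])  # [28, 29, 30, 8, 9, 10, 11, 20]
```

The second sequence occupies indices cu_seqlens[1] = 3 up to (but not including) cu_seqlens[2] = 8 of the flattened input, which is the boundary information a variable-length attention kernel needs. Its fifth token spills into block 5, so its slot is 5 * 4 + 0 = 20.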

Decode Inputs

During decode, each sequence generates exactly one new token. The prepare_decode() method builds a much simpler set of inputs:

def prepare_decode(self, scheduler_output):
    input_ids = []
    positions = []
    slot_mapping = []
    context_lens = []
    block_tables = []

    for seq in scheduler_output.decode_sequences:
        # Only the last generated token
        input_ids.append(seq.get_last_token_id())
        positions.append(seq.get_len() - 1)

        # Slot for the new token
        pos = seq.get_len() - 1
        block_idx = seq.block_table[pos // self.block_size]
        block_offset = pos % self.block_size
        slot_mapping.append(block_idx * self.block_size + block_offset)

        context_lens.append(seq.get_len())
        block_tables.append(seq.block_table)

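    # Pad block tables to a common length so they form a 2-D tensor;
    # -1 marks unused entries (the kernel is expected to read only the
    # blocks covered by context_lens).
    max_blocks = max(len(bt) for bt in block_tables)
    padded_block_tables = torch.tensor(
        [bt + [-1] * (max_blocks - len(bt)) for bt in block_tables],
        dtype=torch.int32,
        device=self.device,
    )

    return ModelInput(
        input_ids=torch.tensor(input_ids, device=self.device),
        positions=torch.tensor(positions, device=self.device),
        slot_mapping=torch.tensor(slot_mapping, device=self.device),
        context_lens=torch.tensor(context_lens, dtype=torch.int32, device=self.device),
        block_tables=padded_block_tables,
        is_prefill=False,
    )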

Key differences from prefill:

  • One token per sequence: input_ids has exactly one entry per sequence
  • block_tables — the decode attention kernel needs the full block table to read all previous K/V values
  • context_lens — tells the kernel how many past tokens each sequence has
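Here is the decode-side counterpart on plain lists, including the block-table padding that the tensor version needs before it can build a 2-D tensor. The function name build_decode_inputs, the -1 filler value, and the toy sequences are assumptions for illustration; the scheme assumes the kernel reads only the block-table entries covered by each context length, so the filler is never dereferenced.

```python
from typing import Dict, List

PAD_BLOCK = -1  # hypothetical filler for unused block-table entries


def build_decode_inputs(seq_tokens: List[List[int]], block_tables: List[List[int]],
                        block_size: int) -> Dict[str, list]:
    """One input token per sequence, plus the metadata a paged decode kernel reads."""
    input_ids, positions, slot_mapping, context_lens = [], [], [], []
    for tokens, table in zip(seq_tokens, block_tables):
        pos = len(tokens) - 1  # position of the newest token
        input_ids.append(tokens[-1])
        positions.append(pos)
        # The newest token's K/V is written to this physical slot
        slot_mapping.append(table[pos // block_size] * block_size + pos % block_size)
        context_lens.append(len(tokens))
    # Block tables differ in length, so pad them into a rectangular 2-D array
    width = max(len(t) for t in block_tables)
    padded = [t + [PAD_BLOCK] * (width - len(t)) for t in block_tables]
    return {"input_ids": input_ids, "positions": positions,
            "slot_mapping": slot_mapping, "context_lens": context_lens,
            "block_tables": padded}


out = build_decode_inputs(
    seq_tokens=[[11, 12, 13, 14], [21, 22, 23, 24, 25, 26]],
    block_tables=[[7], [2, 5]],
    block_size=4,
)
print(out["slot_mapping"])  # [31, 21]
print(out["block_tables"])  # [[7, -1], [2, 5]]
```

Compared with the prefill sketch, the per-step work no longer grows with prompt length: each sequence contributes one token, one slot, and one row of block-table entries.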

Global Context for Attention Metadata

nano-vllm uses a global Context dataclass to pass attention metadata (like cu_seqlens, block_tables, is_prefill) down to the attention layers without threading it through every function signature.

nano-vllm, nanovllm/utils/context.py: a global Context dataclass that carries attention metadata (cu_seqlens, block_tables, slot_mapping) through the model.

This is a pragmatic shortcut. The attention layer reads from this global context to decide whether to use prefill-style (variable-length FlashAttention) or decode-style (paged attention) computation.
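The pattern fits in a few lines. This is a minimal sketch assuming a module-level instance plus set/get/reset helpers; the helper names, the field set, and the attention_path() stand-in are illustrative and are not claimed to match nanovllm/utils/context.py line for line. The runner writes the metadata once per step, and any layer can read it without extra function arguments.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Context:
    """Attention metadata for the current step (hypothetical field set)."""
    is_prefill: bool = False
    cu_seqlens: Optional[List[int]] = None
    slot_mapping: Optional[List[int]] = None
    context_lens: Optional[List[int]] = None
    block_tables: Optional[List[List[int]]] = None


_CONTEXT = Context()


def set_context(**fields) -> None:
    """Install this step's metadata; the runner calls it before the forward pass."""
    global _CONTEXT
    _CONTEXT = Context(**fields)


def get_context() -> Context:
    return _CONTEXT


def reset_context() -> None:
    global _CONTEXT
    _CONTEXT = Context()


def attention_path() -> str:
    """Stand-in for an attention layer choosing its code path from the context."""
    return "varlen-prefill" if get_context().is_prefill else "paged-decode"


set_context(is_prefill=True, cu_seqlens=[0, 3, 8])
print(attention_path())  # varlen-prefill
reset_context()
print(attention_path())  # paged-decode
```

The trade-off is the usual one for global state: it is simple and adds no call overhead, but only one step's metadata can be live at a time, and a missing reset_context() lets stale metadata leak into the next step.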

Model Execution and CUDA Graphs

The run_model() method is where the forward pass actually happens:

def run_model(self, model_input):
    if model_input.is_prefill:
        # Prefill: variable-length input, cannot use CUDA graphs
        logits = self.model(
            model_input.input_ids,
            model_input.positions,
            self.kv_caches,
            is_prefill=True,
        )
    else:
        # Decode: fixed-shape input, use CUDA graph replay
        batch_size = len(model_input.input_ids)
        graph_batch_size = self._get_padded_batch_size(batch_size)

        if graph_batch_size in self.cuda_graphs:
            logits = self._run_with_cuda_graph(model_input, graph_batch_size)
        else:
            logits = self.model(
                model_input.input_ids,
                model_input.positions,
                self.kv_caches,
                is_prefill=False,
            )

    return logits

The key insight: prefill inputs have variable shapes (different prompt lengths), so they must run eagerly. But decode inputs always have the same shape for a given batch size (one token per sequence), making them perfect candidates for CUDA Graphs.

CUDA Graph Capture

CUDA Graphs are a GPU optimization that records a sequence of kernel launches into a graph, then replays the entire graph with a single launch. This removes the CPU-side cost of launching each kernel individually, which matters most during decode: each step processes only one token per sequence, so many kernels finish quickly and the CPU time spent launching them can rival the GPU time spent running them.

The capture_cudagraph() method pre-captures graphs at power-of-2 batch sizes:

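@torch.inference_mode()
def capture_cudagraph(self):
    """Pre-capture CUDA graphs for common decode batch sizes."""
    self.cuda_graphs = {}

    # Capture at power-of-2 batch sizes: 1, 2, 4, 8, ..., up to max_batch_size
    batch_sizes = []
    bs = 1
    while bs <= self.max_batch_size:
        batch_sizes.append(bs)
        bs *= 2

    for bs in batch_sizes:
        # Static buffers with the exact shape. A replayed graph reads these same
        # addresses, so attention metadata must live in static buffers too.
        input_ids = torch.zeros(bs, dtype=torch.long, device=self.device)
        positions = torch.zeros(bs, dtype=torch.long, device=self.device)
        slot_mapping = torch.zeros(bs, dtype=torch.long, device=self.device)
        context_lens = torch.zeros(bs, dtype=torch.int32, device=self.device)
        # max_blocks_per_seq (blocks for the longest allowed sequence) is assumed
        block_tables = torch.zeros(bs, self.max_blocks_per_seq, dtype=torch.int32, device=self.device)
        # set_context() is a hypothetical helper pointing the global Context at these buffers
        set_context(is_prefill=False, slot_mapping=slot_mapping,
                    context_lens=context_lens, block_tables=block_tables)

        # Warmup run (required before capture)
        self.model(input_ids, positions, self.kv_caches, is_prefill=False)

        # Capture the graph
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output = self.model(input_ids, positions, self.kv_caches, is_prefill=False)

        self.cuda_graphs[bs] = {
            "graph": graph,
            "input_ids": input_ids,
            "positions": positions,
            "slot_mapping": slot_mapping,
            "context_lens": context_lens,
            "block_tables": block_tables,
            "output": output,
        }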

Why power-of-2 sizes? CUDA Graphs require fixed tensor shapes. Rather than capturing a graph for every possible batch size, we capture at powers of 2 and pad the actual batch up to the next power of 2. This gives us O(log N) graphs instead of O(N).
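The run_model() code earlier calls a _get_padded_batch_size() helper without showing it. One plausible shape for it, as a sketch: round up to the smallest captured size, and return None when the batch exceeds every captured graph so the caller falls back to eager execution (None is never a key in cuda_graphs). The function names capture_sizes and padded_batch_size are hypothetical.

```python
from typing import List, Optional


def capture_sizes(max_batch_size: int) -> List[int]:
    """Power-of-2 batch sizes up to max_batch_size, as in capture_cudagraph()."""
    sizes, bs = [], 1
    while bs <= max_batch_size:
        sizes.append(bs)
        bs *= 2
    return sizes


def padded_batch_size(batch_size: int, captured: List[int]) -> Optional[int]:
    """Smallest captured size that fits the batch, or None to run eagerly."""
    for size in captured:  # captured is sorted ascending
        if size >= batch_size:
            return size
    return None


sizes = capture_sizes(256)
print(sizes)                          # [1, 2, 4, 8, 16, 32, 64, 128, 256]
print(padded_batch_size(65, sizes))   # 128, i.e. 63 padding rows
print(padded_batch_size(300, sizes))  # None, so the step runs eagerly
```

The example shows the cost of bucketing: a batch of 65 runs as 128, so nearly half the replayed rows are padding. Capturing more sizes reduces that waste at the price of more capture time and memory.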

Graph Replay

When a decode step runs, the model runner:

  1. Pads the batch size up to the next power of 2
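  2. Copies the real inputs (token IDs and positions, plus attention metadata such as the slot mapping, context lengths, and block tables) into the graph’s pre-allocated input tensors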
  3. Replays the graph (a single GPU operation)
  4. Reads the output from the graph’s pre-allocated output tensor, slicing off the padding

This avoids re-launching hundreds of individual CUDA kernels (attention, linear layers, layer norms, etc.) and can reduce decode latency by 10-30%.
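The four replay steps can be illustrated without a GPU. In the toy model below, a ToyGraph object plays the role of a captured graph; it is an analogy, not PyTorch's API, and its "kernel" just doubles each input. What it models faithfully is the key constraint: replay works on fixed buffers, so the runner copies inputs into them instead of passing new objects. In PyTorch the equivalent moves are an in-place copy_() into each captured input tensor followed by graph.replay().

```python
from typing import Dict, List


class ToyGraph:
    """CPU stand-in for a captured decode graph (illustration only).

    Like a real CUDA graph, replay() always reads and writes the buffers that
    existed at capture time, so callers must copy new data into them.
    """

    def __init__(self, batch_size: int):
        self.input_ids = [0] * batch_size  # static input buffer
        self.output = [0] * batch_size     # static output buffer

    def replay(self) -> None:
        # Stand-in for the recorded kernels: output = 2 * input
        for i, tok in enumerate(self.input_ids):
            self.output[i] = tok * 2


def decode_step(graphs: Dict[int, ToyGraph], token_ids: List[int]) -> List[int]:
    bs = len(token_ids)
    graph_bs = min(size for size in graphs if size >= bs)  # 1. pad up
    g = graphs[graph_bs]
    g.input_ids[:bs] = token_ids                           # 2. copy in place
    g.input_ids[bs:] = [0] * (graph_bs - bs)               #    zero the padding rows
    g.replay()                                             # 3. replay
    return g.output[:bs]                                   # 4. slice off padding


graphs = {bs: ToyGraph(bs) for bs in (1, 2, 4, 8)}
buffer_before = graphs[4].input_ids
print(decode_step(graphs, [5, 6, 7]))        # [10, 12, 14]
print(graphs[4].input_ids is buffer_before)  # True: same buffer, new contents
```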

Mapping to Production vLLM

Production vLLM follows the same pattern but with significantly more sophistication:

vllm (production), vllm/v1/worker/gpu_model_runner.py: the production GPU ModelRunner, with the execute_model() entry point, input preparation, and CUDA graph capture and replay.
vllm (production), vllm/v1/worker/gpu_worker.py: the GPU Worker, which manages the model runner in a separate process and handles memory profiling and initialization.

Key differences in production vLLM:

  • Multi-GPU support — the Worker class runs in a separate process per GPU, coordinating via tensor parallelism. The model runner on each worker prepares inputs for its shard.
  • Chunked prefill — long prompts can be split across multiple steps to avoid blocking decode requests. The model runner handles mixed prefill+decode batches.
  • CUDAGraph with torch.compile — vLLM combines CUDA Graphs with torch.compile for additional kernel fusion, getting the best of both optimizations.
  • Encoder-decoder models — the input preparation handles encoder inputs (for models like Whisper or T5) alongside decoder inputs.
  • Speculative decoding integration — the model runner can execute both the draft model and the target model, managing the additional tensor bookkeeping for verification.
  • Detailed profiling — memory profiling accounts for LoRA adapters, multi-modal inputs, and KV cache quantization.
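The idea behind chunked prefill can be sketched as a per-step token budget shared by decode and prefill work. The policy below (decodes first, then as much of each waiting prompt as still fits) is a simplified assumption for illustration, not vLLM's actual scheduling logic, and plan_mixed_batch is a hypothetical name.

```python
from typing import List, Tuple


def plan_mixed_batch(num_decodes: int, prefill_remaining: List[int],
                     token_budget: int) -> Tuple[int, List[int]]:
    """Fill one step's token budget: decodes first (1 token each), then prefill chunks.

    Returns (decode tokens scheduled, prefill tokens scheduled per waiting prompt).
    """
    decode_tokens = min(num_decodes, token_budget)
    budget = token_budget - decode_tokens
    chunks = []
    for remaining in prefill_remaining:
        take = min(remaining, budget)  # a long prompt gets only what is left
        chunks.append(take)
        budget -= take
    return decode_tokens, chunks


# 30 running decodes, a 5000-token prompt and a 300-token prompt, 2048-token budget
print(plan_mixed_batch(30, [5000, 300], 2048))  # (30, [2018, 0])
```

Under this policy the 5000-token prompt finishes its prefill over several steps, while the 30 running requests still produce one token every step instead of stalling behind it.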

Despite the added complexity, the core flow is identical: prepare inputs from scheduler output, run the model (with CUDA Graph replay for decode), return logits.

Summary

  • The ModelRunner bridges scheduling decisions and GPU execution by building the right tensor inputs for each step
  • Prefill inputs are variable-length (flattened tokens + cumulative sequence lengths); decode inputs are fixed-shape (one token per sequence + block tables)
  • CUDA Graphs pre-record decode kernels at power-of-2 batch sizes, then replay them in a single launch to eliminate kernel launch overhead
  • Memory profiling via a dummy forward pass determines how many KV cache blocks fit in GPU memory
  • Production vLLM adds multi-GPU workers, chunked prefill, torch.compile integration, and speculative decoding on top of this same architecture