Chapter 3: Scheduling & Continuous Batching

The scheduler is the brain of the inference engine. Every step(), it decides which requests to run, how many tokens to process for each, and whether to evict requests when memory runs out. This chapter starts with nano-vllm’s 84-line scheduler to build intuition, then maps to vLLM’s production scheduler with chunked prefill, priority queues, and unified scheduling.

Why Scheduling Matters

Without a scheduler, you would process requests one at a time — finish the entire generation for request A before starting request B. This is simple but wasteful: during decode, the GPU processes just one token per request, leaving most of its compute capacity idle.

Continuous batching solves this by dynamically merging requests at different stages into a single batch. A request that just arrived (needs prefill) can share a batch with requests mid-generation (decode). As soon as one request finishes, a new one can take its slot — no waiting for the entire batch to complete.

The scheduler is what makes this possible.

nano-vllm’s Scheduler

nano-vllm
nanovllm/engine/scheduler.py
Scheduler class — 84 lines implementing prefill-priority scheduling, decode batching, and preemption.

nano-vllm’s scheduler manages two queues and produces a SchedulerOutput each step:

class Scheduler:
    def __init__(self, cache_config):
        self.block_manager = BlockManager(cache_config)
        self.waiting: List[Sequence] = []    # new requests awaiting prefill
        self.running: List[Sequence] = []    # active requests in decode phase

    def add_request(self, request_id, token_ids, sampling_params):
        seq = Sequence(request_id, token_ids, sampling_params)
        self.waiting.append(seq)

Two queues, two phases:

  • waiting — sequences that have not yet been prefilled. They need their full prompt processed in one (or more) forward passes.
  • running — sequences that have completed prefill and are generating tokens one at a time.
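The `Sequence` objects flowing through these queues can be sketched as a small dataclass. This is a sketch, not nano-vllm's actual class: only `request_id`, `token_ids`, `sampling_params`, `status`, and `block_table` appear in the snippets in this chapter; the rest is assumed for illustration.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List

class SequenceStatus(Enum):
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()

@dataclass
class Sequence:
    request_id: str
    token_ids: List[int]       # prompt tokens, then generated tokens appended
    sampling_params: dict      # temperature, top_p, ... (placeholder type)
    status: SequenceStatus = SequenceStatus.WAITING
    block_table: List[int] = field(default_factory=list)  # physical KV block ids

    def __len__(self) -> int:
        return len(self.token_ids)

seq = Sequence("req-0", [1, 2, 3], {"temperature": 1.0})
# len(seq) → 3; status stays WAITING until the scheduler admits it
```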

The schedule() Method

Each call to schedule() builds a batch by first trying to admit waiting requests (prefill), then including all running requests (decode):

def schedule(self):
    # Phase 1: Try to move waiting requests into running (prefill)
    scheduled_prefills = []
    while self.waiting:
        seq = self.waiting[0]
        num_blocks_needed = self._get_num_blocks(seq)
        if not self.block_manager.can_allocate(num_blocks_needed):
            break  # not enough memory — stop admitting
        self.block_manager.allocate(seq, num_blocks_needed)
        seq.status = SequenceStatus.RUNNING
        self.waiting.pop(0)
        self.running.append(seq)
        scheduled_prefills.append(seq)

    # Phase 2: Include all running requests (decode)
    scheduled_decodes = [s for s in self.running if s not in scheduled_prefills]

    # Phase 3: Preempt if we cannot allocate decode blocks
    while scheduled_decodes and not self._can_append_slots(scheduled_decodes):
        victim = scheduled_decodes.pop()
        self._preempt(victim)

    return SchedulerOutput(prefills=scheduled_prefills, decodes=scheduled_decodes)

This is a prefill-priority design: new requests are admitted first, then existing decode requests fill the remaining capacity. The logic is straightforward:

  1. Walk the waiting queue front-to-back. For each request, check if the block manager has enough free blocks for its prompt. If yes, allocate blocks and move it to running.
  2. Collect all previously-running sequences as decode candidates.
  3. If there is not enough memory to give each decode sequence its next block, preempt the lowest-priority (last-added) sequences until things fit.
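Phase 1 is the part most worth internalizing: admission stops at the first request that does not fit, even if a smaller request behind it would. The walk can be exercised with a toy block manager (the class and its `free_blocks`/`can_allocate` names are assumptions for this sketch, not nano-vllm's `BlockManager`):

```python
import math

class ToyBlockManager:
    """Tracks a fixed pool of KV cache blocks (16 tokens per block here)."""
    def __init__(self, num_blocks: int, block_size: int = 16):
        self.free_blocks = num_blocks
        self.block_size = block_size
        self.allocated = {}  # request_id -> block count

    def can_allocate(self, n: int) -> bool:
        return self.free_blocks >= n

    def allocate(self, req_id: str, n: int):
        self.free_blocks -= n
        self.allocated[req_id] = n

    def deallocate(self, req_id: str):
        self.free_blocks += self.allocated.pop(req_id)

bm = ToyBlockManager(num_blocks=10)
waiting = [("A", 100), ("B", 100), ("C", 100)]  # (request_id, prompt length)
admitted = []

# Phase 1: admit waiting requests front-to-back while blocks last
while waiting:
    req_id, prompt_len = waiting[0]
    needed = math.ceil(prompt_len / bm.block_size)  # 100 tokens -> 7 blocks
    if not bm.can_allocate(needed):
        break  # not enough memory -- stop admitting
    bm.allocate(req_id, needed)
    waiting.pop(0)
    admitted.append(req_id)

# Only A fits (7 of 10 blocks used); B and C stay in the waiting queue
print(admitted, waiting, bm.free_blocks)
```

Note the strict front-to-back order: B blocks C even though neither fits, preserving FCFS fairness at the cost of potential packing efficiency.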

Preemption

When memory is tight, the scheduler must evict running sequences to make room:

def _preempt(self, seq: Sequence):
    self.block_manager.deallocate(seq)
    seq.status = SequenceStatus.WAITING
    seq.block_table.clear()
    self.running.remove(seq)
    self.waiting.insert(0, seq)  # re-add at front for fairness

The preempted sequence loses its KV cache blocks and goes back to the waiting queue. When it is re-scheduled, it will need a full re-prefill. This is the simplest preemption strategy — production vLLM also supports swapping KV cache to CPU memory to avoid recomputation, but the principle is the same: free GPU memory by sacrificing a running request.

Prefill vs Decode: Two Different Workloads

Understanding why the scheduler treats prefill and decode differently is key:

                     Prefill                                     Decode
                     ───────                                     ──────
Tokens processed     All prompt tokens (hundreds to thousands)   1 new token per request
Compute profile      Compute-bound (large matrix multiplies)     Memory-bound (reading KV cache)
KV cache             Allocated upfront for the full prompt       Grows by 1 token per step
Latency goal         Time-to-first-token (TTFT)                  Inter-token latency (ITL)

The scheduler balances these two workloads. Admitting too many prefills starves decode requests (hurting ITL). Admitting too few delays new requests (hurting TTFT).

Mapping to Production vLLM

Production vLLM’s v1 scheduler is a significant evolution. Let’s trace the key differences.

Unified Scheduler

vllm (production)
vllm/v1/core/sched/scheduler.py
Unified scheduler — no separate prefill/decode phases, supports chunked prefill and priority scheduling.

The production Scheduler does not separate prefill and decode into distinct phases. Instead, it treats every request uniformly:

  • A request with num_computed_tokens < num_prompt_tokens needs prefill work.
  • A request with num_computed_tokens >= num_prompt_tokens needs decode work.
  • Both types are scheduled together in a single pass.
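In code, the prefill/decode distinction collapses to a single comparison per request. A minimal sketch using the two counter names from the bullets above (the `Request` dataclass and helper are illustrative, not vLLM's actual classes):

```python
from dataclasses import dataclass

@dataclass
class Request:
    num_prompt_tokens: int
    num_computed_tokens: int = 0

def tokens_needing_prefill(req: Request) -> int:
    """Remaining prompt tokens; 0 means the request is in decode."""
    return max(0, req.num_prompt_tokens - req.num_computed_tokens)

fresh = Request(num_prompt_tokens=512)                          # untouched prompt
mid = Request(num_prompt_tokens=512, num_computed_tokens=300)   # mid-chunk
decoding = Request(num_prompt_tokens=512, num_computed_tokens=512)

print(tokens_needing_prefill(fresh))     # 512
print(tokens_needing_prefill(mid))       # 212
print(tokens_needing_prefill(decoding))  # 0 -> schedule 1 decode token
```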

This unified approach enables chunked prefill: a long prompt does not need to be processed all at once. The scheduler can allocate a budget of, say, 2048 tokens per step, and split that budget across prefill chunks and decode tokens:

Step 1: [Request A prefill chunk: 1500 tokens] [Request B decode: 1 token] [Request C decode: 1 token] ...
Step 2: [Request A prefill chunk: 500 tokens]  [Request D prefill: 800 tokens] [Request B decode: 1 token] ...
Step 3: [Request A decode: 1 token]            [Request D prefill chunk: 1200 tokens] [Request B decode: 1 token] ...

This prevents a single long prompt from blocking all decode requests for an entire step, dramatically improving inter-token latency for concurrent requests.

Request Queue: FCFS and Priority

vllm (production)
vllm/v1/core/sched/request_queue.py
Request queue implementations — FCFS (default) and priority-based ordering.

nano-vllm uses a simple list for its waiting queue (implicitly FCFS). Production vLLM provides pluggable queue strategies:

  • FCFSRequestQueue — first-come, first-served. Requests are processed in arrival order. This is the default and works well for most serving scenarios.
  • PriorityRequestQueue — requests carry a priority value, and higher-priority requests are scheduled first. Useful for tiered serving where premium users get lower latency.
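A priority queue like this is typically built on a heap keyed by (priority, arrival index), so that equal-priority requests fall back to FCFS. A toy sketch (production vLLM's PriorityRequestQueue differs in detail; here a lower priority value means "serve first", which matches vLLM's convention):

```python
import heapq
import itertools

class ToyPriorityQueue:
    """Lower priority value = served first; ties break by arrival order."""
    def __init__(self):
        self._heap = []
        self._counter = itertools.count()  # monotonically increasing arrival index

    def add(self, request_id: str, priority: int):
        heapq.heappush(self._heap, (priority, next(self._counter), request_id))

    def pop(self) -> str:
        return heapq.heappop(self._heap)[2]

q = ToyPriorityQueue()
q.add("free-tier", priority=10)
q.add("premium", priority=0)
q.add("free-tier-2", priority=10)

order = [q.pop(), q.pop(), q.pop()]
print(order)  # ['premium', 'free-tier', 'free-tier-2']
```

The arrival counter matters: without it, two requests with equal priority would be compared by request id, which is neither fair nor deterministic across runs.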

SchedulerOutput

vllm (production)
vllm/v1/core/sched/output.py
SchedulerOutput — the structured output from the scheduler to the model runner.

The SchedulerOutput in production vLLM carries much more information than nano-vllm’s simple list of sequences:

  • num_scheduled_tokens — a per-request dict specifying how many tokens to process (enables chunked prefill)
  • total_num_scheduled_tokens — the total token budget consumed this step
  • scheduled_new_reqs — newly admitted requests with their full metadata
  • scheduled_cached_reqs — continuing requests that only need a token count update
  • preempted_req_ids — requests that were evicted this step
  • finished_req_ids — requests completed this step

This structured output lets the model runner build its input tensors without needing to inspect individual request objects.
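The shape of this output can be sketched as a dataclass. Field names follow the bullets above; the types and the contents (plain request-id strings rather than vLLM's richer request objects) are simplifying assumptions:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set

@dataclass
class SchedulerOutputSketch:
    # request_id -> tokens to process this step (1 for decode, chunk size for prefill)
    num_scheduled_tokens: Dict[str, int] = field(default_factory=dict)
    total_num_scheduled_tokens: int = 0
    scheduled_new_reqs: List[str] = field(default_factory=list)
    scheduled_cached_reqs: List[str] = field(default_factory=list)
    preempted_req_ids: Set[str] = field(default_factory=set)
    finished_req_ids: Set[str] = field(default_factory=set)

out = SchedulerOutputSketch(
    num_scheduled_tokens={"A": 1500, "B": 1, "C": 1},  # one prefill chunk + two decodes
    scheduled_new_reqs=["A"],
    scheduled_cached_reqs=["B", "C"],
)
out.total_num_scheduled_tokens = sum(out.num_scheduled_tokens.values())
print(out.total_num_scheduled_tokens)  # 1502
```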

Continuous Batching in Action

To see how continuous batching works end-to-end, consider three requests arriving at different times:

Time    Waiting Queue    Running Batch              Action
────    ─────────────    ─────────────              ──────
t=0     [A(512 tok)]     []                         Prefill A
t=1     [B(256 tok)]     [A(decode)]                Prefill B, Decode A
t=2     []               [A(decode), B(decode)]     Decode A+B
t=3     [C(128 tok)]     [A(decode), B(decode)]     Prefill C, Decode A+B
t=4     []               [A(done!), B, C]           A finishes, slot freed
t=5     [D(200 tok)]     [B(decode), C(decode)]     Prefill D into A's slot

The key insight: request D does not wait for B and C to finish. It takes A’s freed slot immediately. The batch size fluctuates dynamically, and the GPU stays busy.

Compare this to static batching (the naive approach), where you would wait for all of A, B, C to finish before starting a new batch. If A generates 10 tokens and C generates 500, A’s GPU slot sits idle for 490 steps.

Chunked Prefill: Handling Long Prompts

Without chunked prefill, a request with a 32K-token prompt would monopolize the GPU for one very long step, stalling all decode requests. Chunked prefill breaks this up:

# Simplified chunked prefill logic (production vLLM)
token_budget = max_num_batched_tokens  # e.g., 2048

for req in request_queue:
    if token_budget <= 0:
        break
    if req.num_computed_tokens < req.num_prompt_tokens:
        # Prefill: allocate a chunk
        remaining = req.num_prompt_tokens - req.num_computed_tokens
        chunk_size = min(remaining, token_budget)
        schedule_tokens(req, chunk_size)
        token_budget -= chunk_size
    else:
        # Decode: costs 1 token
        schedule_tokens(req, 1)
        token_budget -= 1

The token_budget is the total number of tokens the model can process in one step (limited by GPU memory and compute). Prefill chunks and decode tokens compete for this budget, and the scheduler distributes it according to its policy.
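Running the budget loop above against concrete requests makes the interleaving visible. A toy driver (the `Req` fields mirror the attribute names in the loop; `schedule_step` is this chapter's sketch, not a vLLM API):

```python
from dataclasses import dataclass

@dataclass
class Req:
    name: str
    num_prompt_tokens: int
    num_computed_tokens: int = 0

def schedule_step(queue, token_budget):
    """One pass of the simplified chunked-prefill loop; returns {name: tokens}."""
    scheduled = {}
    for req in queue:
        if token_budget <= 0:
            break
        if req.num_computed_tokens < req.num_prompt_tokens:
            # Prefill: take as large a chunk as the remaining budget allows
            remaining = req.num_prompt_tokens - req.num_computed_tokens
            chunk = min(remaining, token_budget)
        else:
            chunk = 1  # decode costs one token
        scheduled[req.name] = chunk
        req.num_computed_tokens += chunk
        token_budget -= chunk
    return scheduled

# Two requests already decoding, plus a 3000-token prompt that needs chunking
queue = [Req("B", 10, 10), Req("C", 10, 10), Req("A", 3000)]
step1 = schedule_step(queue, 2048)
step2 = schedule_step(queue, 2048)
print(step1)  # {'B': 1, 'C': 1, 'A': 2046}
print(step2)  # {'B': 1, 'C': 1, 'A': 954} -- A's prefill completes
```

The decodes for B and C proceed every step; A's 3000-token prompt consumes only the leftover budget, finishing its prefill across two steps instead of stalling everyone for one giant step.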

Preemption Strategies

nano-vllm uses the simplest preemption: deallocate all blocks and re-prefill later. Production vLLM offers more sophisticated options:

  1. Recompute (same as nano-vllm) — discard the KV cache and re-run prefill when the request is re-scheduled. Simple but wastes compute.

  2. Swap — copy the KV cache blocks to CPU memory. When the request is re-scheduled, swap them back to GPU. Saves compute at the cost of CPU memory and PCIe bandwidth.

The choice depends on the workload. Short prompts are cheap to recompute; long prompts benefit from swapping.

Summary

  • The scheduler runs every step(), deciding which requests to prefill and which to decode
  • nano-vllm uses a prefill-priority design: admit new requests first, then batch all running decodes, preempting if memory is tight
  • Continuous batching dynamically merges requests at different stages, keeping the GPU busy as requests arrive and finish
  • Chunked prefill (production vLLM) breaks long prompts into pieces so they do not block decode requests
  • Production vLLM’s unified scheduler treats prefill and decode uniformly with a shared token budget, supports FCFS and priority queues, and produces a structured SchedulerOutput for the model runner
  • Preemption is the escape valve: when memory runs out, the scheduler evicts low-priority requests, either discarding their KV cache (recompute) or swapping it to CPU