Chapter 9: Sampling & Decoding

After the model produces logits — a raw score for every token in the vocabulary — the sampler’s job is to turn those scores into an actual next token. This sounds simple, but the sampling layer is where temperature, top-k, top-p, repetition penalties, structured output constraints, and even speculative decoding all come together. nano-vllm keeps it minimal with a clever trick; production vLLM builds a full pipeline.

The Gumbel-Max Trick

nano-vllm’s entire sampler fits in about 12 lines, and it uses the Gumbel-max trick instead of the usual softmax → multinomial approach.

nano-vllm
nanovllm/layers/sampler.py
Sampler — minimal implementation using the Gumbel-max trick for torch.compile-friendly sampling.
import torch
from torch import nn

class Sampler(nn.Module):
    def forward(self, logits, sampling_params):
        if sampling_params.temperature == 0:
            # Greedy: just take the argmax
            return logits.argmax(dim=-1)

        # Apply temperature
        logits = logits / sampling_params.temperature

        # Gumbel-max trick: add Gumbel noise, then argmax
        # This is equivalent to sampling from softmax(logits)
        u = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(u + 1e-8) + 1e-8)
        return (logits + gumbel_noise).argmax(dim=-1)

Why the Gumbel-max trick? The standard approach — torch.softmax(logits) followed by torch.multinomial() — involves two operations that are hard for torch.compile to fuse. The Gumbel-max trick replaces them with element-wise operations (rand_like, log, add) and an argmax, all of which torch.compile handles well. The result is mathematically equivalent: sampling from softmax(logits / temperature).

The trick works because of a property of the Gumbel distribution: if you add independent Gumbel(0,1) noise to each logit and take the argmax, the probability of selecting each token is exactly softmax(logits). Scaling logits by 1/temperature before adding noise gives you temperature-controlled sampling.
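This equivalence is easy to check empirically. The sketch below (ours, not from either codebase) draws many samples with the Gumbel-max trick and compares the empirical token frequencies against softmax(logits):

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
target = torch.softmax(logits, dim=-1)

# Draw many samples with the Gumbel-max trick and count how often
# each token wins the argmax.
n = 200_000
u = torch.rand(n, logits.shape[0])
gumbel = -torch.log(-torch.log(u))
winners = (logits + gumbel).argmax(dim=-1)
empirical = torch.bincount(winners, minlength=4).float() / n

# Empirical frequencies converge to softmax(logits).
assert torch.allclose(empirical, target, atol=0.01)
```

With 200k draws the per-token error is well under 1%, which is why the assertion holds: adding Gumbel(0,1) noise and taking the argmax really does sample from the softmax distribution.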

Temperature Scaling

Temperature controls the “sharpness” of the probability distribution:

  • temperature = 0 — greedy decoding, always pick the highest-probability token
  • temperature < 1 — sharper distribution, more deterministic (the model is more “confident”)
  • temperature = 1 — sample from the model’s natural distribution
  • temperature > 1 — flatter distribution, more random (the model “explores” more)

Mathematically, dividing logits by temperature before softmax scales the exponents: softmax(logits / T). As T approaches 0, the distribution collapses to a point mass on the max logit. As T grows, the distribution approaches uniform.
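The effect is easy to see with a three-token toy distribution (a sketch of ours, not code from either project):

```python
import torch

logits = torch.tensor([3.0, 1.0, 0.0])

for T in (0.2, 1.0, 5.0):
    probs = torch.softmax(logits / T, dim=-1)
    # Low T concentrates mass on the max logit; high T flattens the distribution.
    print(f"T={T}: {probs.tolist()}")
```

At T = 0.2 nearly all probability mass lands on the first token, while at T = 5.0 the three probabilities are much closer to uniform.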

Top-k and Top-p Filtering

nano-vllm skips top-k/top-p for simplicity, but production vLLM implements both as part of its sampling pipeline.

Top-k keeps only the k highest-probability tokens and zeros out the rest. This prevents the model from sampling extremely unlikely tokens.

Top-p (nucleus sampling) sorts tokens by probability, then keeps the smallest set whose cumulative probability exceeds p. This adapts to the shape of the distribution — when the model is confident, fewer tokens are kept; when uncertain, more are included.

vllm (production)
vllm/v1/sample/ops/topk_topp_sampler.py
Top-k and Top-p sampling kernel — efficient GPU implementation of nucleus sampling.

In vLLM, top-k and top-p are applied as logit masks before the final sampling step. The order matters: top-k first (if set), then top-p on the remaining tokens.
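A minimal sketch of that masking order — top-k first, then top-p on what survives — might look like the following. The helper name and shapes are ours, not vLLM's actual kernel, which is heavily optimized for the GPU:

```python
import torch

def apply_top_k_top_p(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    """Mask logits outside top-k, then outside the top-p nucleus, with -inf."""
    # Top-k: keep only the k largest logits per row.
    kth = logits.topk(k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))

    # Top-p: sort descending, keep the smallest prefix whose cumulative
    # probability reaches p (a token is dropped once the mass *before* it >= p).
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_before = probs.cumsum(dim=-1) - probs
    sorted_logits = sorted_logits.masked_fill(cum_before >= p, float("-inf"))

    # Scatter the filtered logits back to their original positions.
    return logits.scatter(-1, sorted_idx, sorted_logits)

logits = torch.tensor([[2.0, 1.0, 0.5, -3.0, -4.0]])
filtered = apply_top_k_top_p(logits, k=3, p=0.8)
# Tokens 3 and 4 fall to top-k; token 2 falls outside the 0.8 nucleus.
```

Because masked tokens are set to -inf, a subsequent softmax assigns them exactly zero probability, so any sampling method downstream can never pick them.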

The Full Sampling Pipeline in vLLM

Production vLLM’s sampler is a multi-stage pipeline that processes logits through several transformations before sampling:

vllm (production)
vllm/v1/sample/sampler.py
Full sampling pipeline — logprobs computation, float32 conversion, allowed token masking, penalties, and final sampling.

The pipeline stages are:

  1. Cast to float32 — logits come out of the model in bfloat16/float16 for speed, but sampling needs float32 precision to avoid numerical issues in softmax and log operations.

  2. Compute logprobs — if the request asks for log probabilities (common in evaluation and API responses), they are computed from the raw logits before any modifications.

  3. Apply allowed token mask — for structured output / guided decoding, a mask restricts which tokens are valid at this position (e.g., only tokens that continue a valid JSON string).

  4. Apply penalties — repetition, frequency, and presence penalties modify logits to discourage or encourage certain tokens.

  5. Apply temperature — scale logits by 1/temperature.

  6. Apply top-k / top-p — filter to the most likely tokens.

  7. Sample — draw the final token from the filtered distribution.

Sampling Metadata

Each request can have different sampling parameters. The SamplingMetadata class batches these per-request parameters into tensors so the sampler can process the entire batch in one GPU operation:

vllm (production)
vllm/v1/sample/metadata.py
SamplingMetadata — batches per-request sampling parameters (temperature, top-k, top-p, penalties) into GPU tensors.
vllm (production)
vllm/sampling_params.py
SamplingParams — full parameter specification including temperature, top-k, top-p, penalties, stop conditions, and guided decoding options.
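The batching idea can be sketched in a few lines. This is an illustration of the pattern, not vLLM's actual SamplingMetadata fields — one tensor entry per request, with greedy requests handled specially so the division never sees a zero temperature:

```python
import torch

# Per-request temperatures batched into one tensor (0.0 means greedy).
temperatures = torch.tensor([0.0, 0.7, 1.0, 1.3])
logits = torch.randn(4, 32_000)  # one row of logits per request

# Avoid dividing by zero: greedy rows keep raw logits and later use argmax.
greedy = temperatures == 0
safe_t = torch.where(greedy, torch.ones_like(temperatures), temperatures)
scaled = logits / safe_t.unsqueeze(-1)

# One batched division serves every request, whatever its temperature.
```

This is the payoff of SamplingMetadata: heterogeneous per-request parameters become one vectorized GPU operation instead of a Python loop over requests.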

Repetition, Frequency, and Presence Penalties

These penalties discourage the model from repeating itself, each in a slightly different way:

vllm (production)
vllm/v1/sample/ops/penalties.py
Penalty implementations — repetition, frequency, and presence penalties applied to logits.
  • Repetition penalty — divides the logit of any previously generated token by the penalty factor (if the logit is positive) or multiplies it (if negative). A penalty of 1.0 means no change; > 1.0 discourages repetition.

  • Frequency penalty — subtracts penalty * count from each token’s logit, where count is how many times that token has appeared. Tokens that appear more often get penalized more.

  • Presence penalty — subtracts a flat penalty from any token that has appeared at least once, regardless of how many times. This encourages topic diversity.

In vLLM, these penalties are applied as a batched GPU operation. The sampler maintains a token count tensor per sequence and updates it after each generated token.
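The three penalties above can be sketched as one function over a token-count tensor. This is an illustrative simplification (single sequence, default values of our choosing), not vLLM's batched kernel:

```python
import torch

def apply_penalties(logits, counts, rep_pen=1.2, freq_pen=0.1, pres_pen=0.2):
    """counts[i] = how many times token i has appeared in the output so far."""
    seen = counts > 0

    # Repetition penalty: divide positive logits, multiply negative ones,
    # but only for tokens that have already appeared.
    penalized = torch.where(logits > 0, logits / rep_pen, logits * rep_pen)
    logits = torch.where(seen, penalized, logits)

    # Frequency penalty: grows with how often the token has occurred.
    logits = logits - freq_pen * counts

    # Presence penalty: flat subtraction for any token seen at least once.
    logits = logits - pres_pen * seen.float()
    return logits

out = apply_penalties(torch.tensor([2.0, -1.0, 0.5]),
                      counts=torch.tensor([3.0, 1.0, 0.0]))
# Token 2 never appeared, so its logit of 0.5 is untouched.
```

Note the asymmetry in the repetition penalty: dividing a positive logit and multiplying a negative one both push the logit toward "less likely", which is why the two cases are handled separately.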

Structured Output and Guided Decoding

One of vLLM’s powerful features is guided decoding — constraining the model’s output to follow a specific format (JSON schema, regex pattern, grammar). This works by computing an allowed token mask at each step:

  1. A grammar engine (like outlines or lm-format-enforcer) tracks the current state of the output
  2. At each step, it computes which tokens are valid continuations
  3. The sampler sets all invalid token logits to -inf before sampling

This happens at the “apply allowed token mask” stage of the pipeline. The mask is a boolean tensor over the vocabulary — True for allowed tokens, False for forbidden ones.
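The masking step itself is a one-liner. A toy sketch with an 8-token vocabulary (the allowed indices are arbitrary stand-ins for whatever the grammar engine returns):

```python
import torch

vocab_size = 8
logits = torch.randn(vocab_size)

# Boolean mask from the grammar engine: True = valid continuation.
allowed = torch.zeros(vocab_size, dtype=torch.bool)
allowed[[1, 4, 5]] = True  # e.g., only tokens that keep the JSON valid

masked = logits.masked_fill(~allowed, float("-inf"))

# After softmax, forbidden tokens carry exactly zero probability.
assert torch.softmax(masked, dim=-1)[~allowed].sum().item() == 0.0
```

Because the constraint is enforced on the logits rather than by post-hoc filtering of outputs, every decoding strategy downstream — greedy, sampled, or speculative — automatically respects the grammar.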

Speculative Decoding

Speculative decoding is an advanced technique that uses a smaller, faster “draft” model to propose multiple tokens at once, then verifies them with the full model in a single forward pass. When the draft model’s predictions are correct (which happens often), you get multiple tokens for the cost of one large-model forward pass.

vllm (production)
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
Rejection sampler — verifies draft model proposals against the target model's distribution.

The verification uses rejection sampling: for each proposed token, compare the draft model’s probability with the target model’s probability. Accept the token with probability min(1, p_target / p_draft). If rejected, resample from an adjusted distribution. This guarantees the output distribution is identical to what the target model would produce alone.
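The accept/reject step for a single proposed token can be sketched as follows. This is a simplified, single-token version of the idea; vLLM's rejection sampler verifies whole batches of draft tokens at once:

```python
import torch

def verify_token(token, p_draft, p_target):
    """Accept a draft token with prob min(1, p_target/p_draft); else resample."""
    accept_prob = torch.clamp(p_target[token] / p_draft[token], max=1.0)
    if torch.rand(()) < accept_prob:
        return token, True
    # On rejection, resample from the adjusted distribution
    # max(0, p_target - p_draft), renormalized. This correction is what
    # makes the overall output distribution exactly match the target model.
    adjusted = torch.clamp(p_target - p_draft, min=0.0)
    adjusted = adjusted / adjusted.sum()
    return torch.multinomial(adjusted, 1).item(), False

p_draft = torch.tensor([0.5, 0.5])
p_target = torch.tensor([0.9, 0.1])
tok, accepted = verify_token(0, p_draft, p_target)
# p_target/p_draft = 1.8 for token 0, so it is always accepted.
```

When the target model assigns a proposed token more probability than the draft did, acceptance is certain; when it assigns less, the acceptance probability shrinks proportionally, and the resampling step redistributes the rejected mass.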

vLLM supports several proposer strategies:

vllm (production)
vllm/v1/spec_decode/
Speculative decoding proposers — ngram, EAGLE, and Medusa strategies for draft token generation.
  • N-gram proposer — uses patterns in the existing output to predict next tokens (no draft model needed)
  • EAGLE — a lightweight draft head trained on the target model’s hidden states
  • Medusa — multiple draft heads that predict several future tokens in parallel
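The n-gram proposer is simple enough to sketch in full. This toy version (ours, not vLLM's implementation) looks for the most recent earlier occurrence of the last n generated tokens and proposes the tokens that followed it:

```python
def ngram_propose(tokens: list[int], n: int = 2, k: int = 3) -> list[int]:
    """If the last n tokens appeared earlier, propose the k tokens that
    followed that earlier occurrence; otherwise propose nothing."""
    if len(tokens) <= n:
        return []
    suffix = tokens[-n:]
    # Scan backward for the most recent earlier match of the suffix.
    for i in range(len(tokens) - n - 1, -1, -1):
        if tokens[i:i + n] == suffix:
            return tokens[i + n:i + n + k]
    return []

# The sequence ends in [5, 7], which appeared at the start followed by 9, 3, 5.
print(ngram_propose([5, 7, 9, 3, 5, 7]))
```

This is why the n-gram strategy shines on repetitive text (code, structured data, quoted passages): proposals cost no extra model forward pass at all, and the rejection sampler discards them harmlessly when they are wrong.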

Mapping to Production vLLM

The jump from nano-vllm’s 12-line sampler to production vLLM’s sampling pipeline reflects the real-world requirements of an inference server:

| Feature                | nano-vllm        | Production vLLM             |
| ---------------------- | ---------------- | --------------------------- |
| Temperature            | Yes (Gumbel-max) | Yes                         |
| Top-k / Top-p          | No               | Yes (GPU kernel)            |
| Repetition penalties   | No               | Yes (rep/freq/presence)     |
| Logprobs               | No               | Yes                         |
| Structured output      | No               | Yes (grammar-based masking) |
| Speculative decoding   | No               | Yes (ngram, EAGLE, Medusa)  |
| torch.compile friendly | Yes (Gumbel-max) | Yes                         |

The Gumbel-max trick in nano-vllm is not just a simplification — it is the same approach production vLLM uses for its core sampling step, precisely because it composes well with torch.compile. The production system wraps it with the additional pipeline stages needed for a full-featured API.

Summary

  • Sampling converts model logits into discrete token choices through temperature scaling, filtering, and randomized selection
  • The Gumbel-max trick replaces softmax + multinomial with add noise + argmax, making sampling torch.compile-friendly
  • Production vLLM adds a multi-stage pipeline: float32 cast, logprobs, allowed token masking, penalties, top-k/top-p, then sampling
  • Repetition/frequency/presence penalties each discourage repetition in different ways, applied as batched GPU operations
  • Structured output works by masking invalid tokens to -inf before sampling, guided by a grammar engine
  • Speculative decoding proposes multiple tokens with a fast draft model and verifies them via rejection sampling against the target model