Chapter 9: Sampling & Decoding
After the model produces logits — a raw score for every token in the vocabulary — the sampler’s job is to turn those scores into an actual next token. This sounds simple, but the sampling layer is where temperature, top-k, top-p, repetition penalties, structured output constraints, and even speculative decoding all come together. nano-vllm keeps it minimal with a clever trick; production vLLM builds a full pipeline.
The Gumbel-Max Trick
nano-vllm’s entire sampler fits in about 12 lines, and it uses the Gumbel-max trick instead of the usual softmax → multinomial approach.
```python
import torch
from torch import nn

class Sampler(nn.Module):
    def forward(self, logits, sampling_params):
        if sampling_params.temperature == 0:
            # Greedy: just take the argmax
            return logits.argmax(dim=-1)
        # Apply temperature
        logits = logits / sampling_params.temperature
        # Gumbel-max trick: add Gumbel noise, then argmax.
        # This is equivalent to sampling from softmax(logits).
        u = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(u + 1e-8) + 1e-8)
        return (logits + gumbel_noise).argmax(dim=-1)
```
Why the Gumbel-max trick? The standard approach — torch.softmax(logits) followed by torch.multinomial() — involves two operations that are hard for torch.compile to fuse. The Gumbel-max trick replaces them with element-wise operations (rand_like, log, add) and an argmax, all of which torch.compile handles well. The result is mathematically equivalent: sampling from softmax(logits / temperature).
The trick works because of a property of the Gumbel distribution: if you add independent Gumbel(0,1) noise to each logit and take the argmax, the probability of selecting each token is exactly softmax(logits). Scaling logits by 1/temperature before adding noise gives you temperature-controlled sampling.
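This equivalence is easy to check empirically. Below is a minimal NumPy sketch (not nano-vllm code) that compares Gumbel-max sample frequencies against the softmax probabilities:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0])

# Reference distribution: softmax(logits)
probs = np.exp(logits) / np.exp(logits).sum()

# Gumbel-max: add independent Gumbel(0,1) noise to each logit, take argmax
n = 200_000
u = rng.random((n, logits.size))
gumbel = -np.log(-np.log(u))
samples = (logits + gumbel).argmax(axis=1)

# Empirical frequencies converge to the softmax probabilities
freqs = np.bincount(samples, minlength=logits.size) / n
```

With 200k draws the empirical frequencies agree with `softmax(logits)` to within about a percent, which is exactly the Gumbel-max property the sampler relies on.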
Temperature Scaling
Temperature controls the “sharpness” of the probability distribution:
- temperature = 0 — greedy decoding, always pick the highest-probability token
- temperature < 1 — sharper distribution, more deterministic (the model is more “confident”)
- temperature = 1 — sample from the model’s natural distribution
- temperature > 1 — flatter distribution, more random (the model “explores” more)
Mathematically, dividing logits by temperature before softmax scales the exponents: softmax(logits / T). As T approaches 0, the distribution collapses to a point mass on the max logit. As T grows, the distribution approaches uniform.
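A quick NumPy illustration of how T reshapes the same logits:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())   # subtract max for numerical stability
    return z / z.sum()

logits = np.array([2.0, 1.0, 0.0])

# Low T sharpens the distribution toward the argmax;
# high T flattens it toward uniform
for T in (0.2, 1.0, 5.0):
    print(f"T={T}: {np.round(softmax(logits / T), 3)}")
```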
Top-k and Top-p Filtering
nano-vllm skips top-k/top-p for simplicity, but production vLLM implements both as part of its sampling pipeline.
Top-k keeps only the k highest-probability tokens and zeros out the rest. This prevents the model from sampling extremely unlikely tokens.
Top-p (nucleus sampling) sorts tokens by probability, then keeps the smallest set whose cumulative probability exceeds p. This adapts to the shape of the distribution — when the model is confident, fewer tokens are kept; when uncertain, more are included.
In vLLM, top-k and top-p are applied as logit masks before the final sampling step. The order matters: top-k first (if set), then top-p on the remaining tokens.
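Here is a single-sequence NumPy sketch of both filters as logit masks (the function name is illustrative; vLLM's real implementation is a batched GPU kernel):

```python
import numpy as np

def top_k_top_p_filter(logits, k=0, p=1.0):
    """Mask logits outside top-k / top-p to -inf; k=0 and p=1.0 disable them."""
    logits = logits.astype(np.float64)  # work on a copy
    if k > 0:
        # Top-k first: keep only the k largest logits
        kth_largest = np.sort(logits)[-k]
        logits[logits < kth_largest] = -np.inf
    if p < 1.0:
        # Then top-p on what remains: keep the smallest prefix of the
        # sorted distribution whose cumulative probability reaches p
        order = np.argsort(logits)[::-1]
        probs = np.exp(logits[order] - logits[order][0])
        probs /= probs.sum()
        cutoff = int(np.searchsorted(np.cumsum(probs), p)) + 1
        logits[order[cutoff:]] = -np.inf
    return logits
```

After filtering, sampling proceeds on the masked logits; tokens at `-inf` have zero probability.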
The Full Sampling Pipeline in vLLM
Production vLLM’s sampler is a multi-stage pipeline that processes logits through several transformations before sampling:
The pipeline stages are:
- Cast to float32 — logits come out of the model in bfloat16/float16 for speed, but sampling needs float32 precision to avoid numerical issues in softmax and log operations.
- Compute logprobs — if the request asks for log probabilities (common in evaluation and API responses), they are computed from the raw logits before any modifications.
- Apply allowed token mask — for structured output / guided decoding, a mask restricts which tokens are valid at this position (e.g., only tokens that continue a valid JSON string).
- Apply penalties — repetition, frequency, and presence penalties modify logits to discourage or encourage certain tokens.
- Apply temperature — scale logits by `1/temperature`.
- Apply top-k / top-p — filter to the most likely tokens.
- Sample — draw the final token from the filtered distribution.
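The stages above can be sketched for a single sequence in NumPy. Everything here is illustrative (the function and argument names are not vLLM's actual API), with the penalty and top-k/top-p stages marked where they would slot in:

```python
import numpy as np

def log_softmax(x):
    m = x.max()
    return x - (m + np.log(np.exp(x - m).sum()))

def sample_token(logits, temperature=1.0, allowed_mask=None, rng=None):
    """Illustrative sampling pipeline for one sequence."""
    if rng is None:
        rng = np.random.default_rng()
    logits = logits.astype(np.float32)             # 1. cast to float32
    logprobs = log_softmax(logits)                 # 2. logprobs from raw logits
    if allowed_mask is not None:                   # 3. structured-output mask
        logits = np.where(allowed_mask, logits, -np.inf)
    # 4. penalties would be applied here
    if temperature == 0:                           # greedy decoding
        return int(logits.argmax()), logprobs
    logits = logits / temperature                  # 5. temperature
    # 6. top-k / top-p filtering would be applied here
    u = rng.random(logits.size)                    # 7. Gumbel-max sample
    gumbel = -np.log(-np.log(u + 1e-8) + 1e-8)
    return int((logits + gumbel).argmax()), logprobs
```

Note that logprobs are taken before masking and temperature, matching the stage order above.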
Sampling Metadata
Each request can have different sampling parameters. The SamplingMetadata class batches these per-request parameters into tensors so the sampler can process the entire batch in one GPU operation.
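The real class carries many more fields; the following simplified NumPy sketch (with illustrative field names, not vLLM's actual definition) shows just the batching idea:

```python
import numpy as np

class SamplingMetadataSketch:
    """Per-request sampling params stacked into arrays (simplified sketch)."""
    def __init__(self, requests):
        self.temperature = np.array([r["temperature"] for r in requests],
                                    dtype=np.float32)
        self.top_p = np.array([r.get("top_p", 1.0) for r in requests],
                              dtype=np.float32)
        self.top_k = np.array([r.get("top_k", 0) for r in requests],
                              dtype=np.int64)

# With the parameters stacked, an expression like
# `logits / meta.temperature[:, None]` applies every request's
# temperature in a single vectorized operation
meta = SamplingMetadataSketch([{"temperature": 0.7},
                               {"temperature": 1.0, "top_p": 0.9}])
```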
Repetition, Frequency, and Presence Penalties
These penalties discourage the model from repeating itself, each in a slightly different way:
- Repetition penalty — divides the logit of any previously generated token by the penalty factor (if the logit is positive) or multiplies it (if negative). A penalty of 1.0 means no change; > 1.0 discourages repetition.
- Frequency penalty — subtracts `penalty * count` from each token’s logit, where `count` is how many times that token has appeared. Tokens that appear more often get penalized more.
- Presence penalty — subtracts a flat penalty from any token that has appeared at least once, regardless of how many times. This encourages topic diversity.
In vLLM, these penalties are applied as a batched GPU operation. The sampler maintains a token count tensor per sequence and updates it after each generated token.
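A single-sequence NumPy sketch of all three penalties (vLLM's batched kernel does the same arithmetic on whole tensors at once):

```python
import numpy as np

def apply_penalties(logits, counts, rep=1.0, freq=0.0, pres=0.0):
    """Apply repetition, frequency, and presence penalties to one sequence.

    counts[i] is how many times token i has appeared in the output so far.
    """
    logits = logits.astype(np.float32).copy()
    seen = counts > 0
    # Repetition: divide positive logits by the factor, multiply negative ones
    pos = seen & (logits > 0)
    neg = seen & (logits <= 0)
    logits[pos] /= rep
    logits[neg] *= rep
    # Frequency: penalty proportional to how often the token appeared
    logits -= freq * counts
    # Presence: flat penalty for any token that appeared at all
    logits -= pres * seen.astype(np.float32)
    return logits
```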
Structured Output and Guided Decoding
One of vLLM’s powerful features is guided decoding — constraining the model’s output to follow a specific format (JSON schema, regex pattern, grammar). This works by computing an allowed token mask at each step:
- A grammar engine (like `outlines` or `lm-format-enforcer`) tracks the current state of the output
- At each step, it computes which tokens are valid continuations
- The sampler sets all invalid token logits to `-inf` before sampling
This happens at the “apply allowed token mask” stage of the pipeline. The mask is a boolean tensor over the vocabulary — True for allowed tokens, False for forbidden ones.
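In miniature, that masking stage looks like this (the vocabulary size and allowed token indices are made up for illustration):

```python
import numpy as np

vocab_size = 8
logits = np.random.default_rng(0).normal(size=vocab_size).astype(np.float32)

# Pretend the grammar engine reports that only tokens 2 and 5 are
# valid continuations at this position
allowed = np.zeros(vocab_size, dtype=bool)
allowed[[2, 5]] = True

# Invalid tokens go to -inf, so sampling can only ever pick token 2 or 5
masked = np.where(allowed, logits, -np.inf)
token = int(masked.argmax())
```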
Speculative Decoding
Speculative decoding is an advanced technique that uses a smaller, faster “draft” model to propose multiple tokens at once, then verifies them with the full model in a single forward pass. When the draft model’s predictions are correct (which happens often), you get multiple tokens for the cost of one large-model forward pass.
The verification uses rejection sampling: for each proposed token, compare the draft model’s probability with the target model’s probability. Accept the token with probability min(1, p_target / p_draft). If rejected, resample from an adjusted distribution. This guarantees the output distribution is identical to what the target model would produce alone.
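The accept/resample rule can be sketched in NumPy for a single position. Running the simulation shows the empirical output distribution matching the target model's rather than the draft's (both toy distributions here are assumptions for illustration):

```python
import numpy as np

def verify_token(token, p_draft, p_target, rng):
    """Accept the draft token with prob min(1, p_target/p_draft);
    on rejection, resample from max(0, p_target - p_draft), renormalized."""
    if rng.random() < min(1.0, p_target[token] / p_draft[token]):
        return token
    adjusted = np.maximum(p_target - p_draft, 0.0)
    return int(rng.choice(len(p_target), p=adjusted / adjusted.sum()))

rng = np.random.default_rng(0)
p_draft = np.array([0.7, 0.2, 0.1])    # toy draft distribution
p_target = np.array([0.4, 0.4, 0.2])   # toy target distribution

n = 50_000
draft_tokens = rng.choice(3, size=n, p=p_draft)
out = [verify_token(t, p_draft, p_target, rng) for t in draft_tokens]

# Frequencies converge to p_target regardless of the draft distribution
freqs = np.bincount(out, minlength=3) / n
```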
vLLM supports several proposer strategies:
- N-gram proposer — uses patterns in the existing output to predict next tokens (no draft model needed)
- EAGLE — a lightweight draft head trained on the target model’s hidden states
- Medusa — multiple draft heads that predict several future tokens in parallel
Mapping to Production vLLM
The jump from nano-vllm’s 12-line sampler to production vLLM’s sampling pipeline reflects the real-world requirements of an inference server:
| Feature | nano-vllm | Production vLLM |
|---|---|---|
| Temperature | Yes (Gumbel-max) | Yes |
| Top-k / Top-p | No | Yes (GPU kernel) |
| Repetition penalties | No | Yes (rep/freq/presence) |
| Logprobs | No | Yes |
| Structured output | No | Yes (grammar-based masking) |
| Speculative decoding | No | Yes (ngram, EAGLE, Medusa) |
| torch.compile friendly | Yes (Gumbel-max) | Yes |
The Gumbel-max trick in nano-vllm is not just a simplification — it is the same approach production vLLM uses for its core sampling step, precisely because it composes well with torch.compile. The production system wraps it with the additional pipeline stages needed for a full-featured API.
Summary
- Sampling converts model logits into discrete token choices through temperature scaling, filtering, and randomized selection
- The Gumbel-max trick replaces `softmax` + `multinomial` with add-noise + `argmax`, making sampling `torch.compile`-friendly
- Production vLLM adds a multi-stage pipeline: float32 cast, logprobs, allowed token masking, penalties, top-k/top-p, then sampling
- Repetition/frequency/presence penalties each discourage repetition in different ways, applied as batched GPU operations
- Structured output works by masking invalid tokens to `-inf` before sampling, guided by a grammar engine
- Speculative decoding proposes multiple tokens with a fast draft model and verifies them via rejection sampling against the target model