Chapter 10: Weight Loading

Before a model can serve a single request, its weights need to get from disk into GPU memory — and for tensor-parallel inference, each GPU needs exactly the right shard. This chapter covers how nano-vllm loads HuggingFace checkpoints into a sharded model, and how production vLLM extends this with quantization, multiple formats, and distributed loading.

The Weight Loader Protocol

The central idea in both nano-vllm and vLLM is that each model parameter knows how to load itself. Instead of a monolithic loader that understands every layer’s sharding pattern, each linear layer has a weight_loader method that accepts a raw checkpoint tensor and extracts the correct shard.

This inversion of control is what makes the system extensible — adding a new model architecture only requires defining the right weight_loader methods on its layers.

nano-vllm’s Loader

nano-vllm’s entire weight loading logic fits in about 28 lines:

nano-vllm
nanovllm/utils/loader.py
load_model() — iterates safetensors shards, maps checkpoint names to model parameters, and calls per-parameter weight_loader methods.
def load_model(model, model_name):
    # Mapping from model parameter names to the parameters themselves
    params = dict(model.named_parameters())

    # The model fuses related projections into single parameters, but the
    # checkpoint stores them separately. e.g., the checkpoint's q_proj,
    # k_proj, and v_proj tensors all load into the model's fused qkv_proj.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    # Invert it: checkpoint sub-module name -> fused model parameter name
    sub_to_packed = {
        sub_name: packed_name
        for packed_name, sub_modules in packed_modules_mapping.items()
        for sub_name in sub_modules
    }

    for name, checkpoint_tensor in iterate_safetensors(model_name):
        for sub_name, packed_name in sub_to_packed.items():
            if sub_name in name:
                # One piece of a packed/fused parameter: the shard id tells
                # the weight_loader where to place it within the fused tensor
                param_name = name.replace(sub_name, packed_name)
                param = params[param_name]
                param.weight_loader(param, checkpoint_tensor, sub_name)
                break
        else:
            # Direct 1:1 mapping (skip names the model doesn't have,
            # e.g. rotary embedding buffers)
            if name in params:
                param = params[name]
                param.weight_loader(param, checkpoint_tensor)

The flow is:

  1. Iterate safetensors — walk through each tensor in the checkpoint file(s)
  2. Name mapping — match checkpoint tensor names to model parameter names
  3. Packed modules lookup — check whether this checkpoint tensor is one piece of a fused model parameter (and if so, which piece)
  4. Call weight_loader — each parameter’s weight_loader method extracts the right shard from the checkpoint tensor

Safetensors Format

Modern HuggingFace models use the safetensors format — a simple, memory-mappable binary format that stores named tensors. Unlike pickle-based .bin files, safetensors files can be read without executing arbitrary code, and individual tensors can be loaded without reading the entire file.

For large models, the checkpoint is split across multiple safetensors shards (e.g., model-00001-of-00003.safetensors). The loader iterates through all shards, yielding (name, tensor) pairs.
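
The format is simple enough to see why per-tensor reads are cheap: each file starts with an 8-byte little-endian header length, then a JSON table mapping every tensor name to its dtype, shape, and byte offsets in the data section. A minimal header reader (stdlib only, for illustration; real loaders use the safetensors library):

```python
import json
import struct

def read_safetensors_header(path):
    """Read only the JSON header of a .safetensors file.

    The header maps each tensor name to its dtype, shape, and data_offsets.
    Knowing the offsets, a loader can mmap or seek to any single tensor
    without reading the rest of the file.
    """
    with open(path, "rb") as f:
        # First 8 bytes: little-endian u64 giving the JSON header's length
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))
```

The `data_offsets` entries are what make lazy, per-tensor loading possible: the loader never has to scan past tensors it does not need.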

Packed Modules Mapping

Many transformer implementations fuse related projections for efficiency. For example, instead of separate Q, K, V linear layers, the model uses a single qkv_proj that computes all three in one matrix multiplication. Similarly, gate_up_proj fuses the gate and up projections in the MLP.

But HuggingFace checkpoints often store these as separate weights (q_proj.weight, k_proj.weight, v_proj.weight). The packed modules mapping tells the loader how to reassemble them:

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

When the model has a fused qkv_proj parameter, and the checkpoint provides separate q_proj, k_proj, v_proj tensors, the weight_loader on the fused parameter knows how to slot each piece into the right position within the larger fused weight matrix.

Conversely, if the checkpoint stores a fused weight and the model expects separate parameters, the weight_loader extracts the appropriate slice.
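
As a concrete sketch of the common case (hypothetical helper, using HF-Llama-style checkpoint names), the remapping reduces to a string substitution plus a shard id:

```python
PACKED_MODULES_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def remap(checkpoint_name):
    """Map a checkpoint tensor name to (model parameter name, shard id).

    A shard id of None means the tensor maps 1:1 onto a model parameter.
    """
    for packed_name, sub_modules in PACKED_MODULES_MAPPING.items():
        for sub_name in sub_modules:
            if sub_name in checkpoint_name:
                # e.g. "...self_attn.k_proj.weight" -> "...self_attn.qkv_proj.weight"
                return checkpoint_name.replace(sub_name, packed_name), sub_name
    return checkpoint_name, None
```

For example, `remap("model.layers.0.self_attn.k_proj.weight")` yields `("model.layers.0.self_attn.qkv_proj.weight", "k_proj")`, and the shard id tells the fused parameter's weight_loader which slot the tensor fills.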

The weight_loader Method

Each parameter’s weight_loader is where tensor-parallel sharding happens. For a column-parallel linear layer (like the Q/K/V projections), the weight_loader:

  1. Takes the full checkpoint tensor
  2. Determines the current TP rank
  3. Slices out the correct shard along the output dimension
  4. Copies it into the parameter
# Conceptual weight_loader for a column-parallel linear layer
def weight_loader(self, param, loaded_weight, loaded_shard_id=None):
    tp_rank = get_tensor_parallel_rank()
    tp_size = get_tensor_parallel_world_size()

    # Shard along the output dimension
    shard_size = loaded_weight.shape[0] // tp_size
    start = tp_rank * shard_size
    end = start + shard_size

    if loaded_shard_id is not None:
        # This is part of a packed module — place it at the right offset
        shard_offset = self.packed_offsets[loaded_shard_id]
        param.data[shard_offset:shard_offset + shard_size].copy_(
            loaded_weight[start:end]
        )
    else:
        param.data.copy_(loaded_weight[start:end])

This protocol means the loader itself is model-agnostic. It just iterates tensors and calls weight_loader — the parameter knows its own sharding strategy.
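
To make the index arithmetic concrete, here is a runnable toy (hypothetical sizes, plain Python) showing where each checkpoint slice lands inside each rank's fused qkv_proj when tp_size is 2:

```python
def shard_bounds(full_rows, tp_rank, tp_size):
    """Rows of the full checkpoint tensor owned by this rank.

    Column-parallel layers split the output dimension evenly across ranks.
    """
    shard = full_rows // tp_size
    return tp_rank * shard, (tp_rank + 1) * shard

# Toy fused QKV: 8 query rows, 4 key rows, 4 value rows, across 2 ranks.
# Each rank's qkv_proj then holds 8 rows: 4 of Q, 2 of K, 2 of V.
tp_size = 2
sizes = {"q_proj": 8, "k_proj": 4, "v_proj": 4}

for rank in range(tp_size):
    offset = 0  # where the next piece lands inside this rank's fused param
    for name, rows in sizes.items():
        start, end = shard_bounds(rows, rank, tp_size)
        print(f"rank {rank}: {name}[{start}:{end}] -> qkv_proj[{offset}:{offset + end - start}]")
        offset += end - start
```

Note that the two indices come from different tensors: `start:end` slices the full checkpoint tensor, while `offset` addresses this rank's local fused parameter.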

Mapping to Production vLLM

Production vLLM builds a much more sophisticated loading system on top of the same core protocol:

vllm (production)
vllm/model_executor/model_loader/default_loader.py
DefaultModelLoader — production weight loading with support for multiple formats, distributed loading, and quantization.
vllm (production)
vllm/model_executor/model_loader/weight_utils.py
Weight utilities — downloading checkpoints, iterating safetensors shards, and weight name remapping.
vllm (production)
vllm/model_executor/model_loader/__init__.py
get_model_loader() factory — selects the right loader based on model config (default, dummy, sharded, etc.).

Loader Factory

vLLM uses a factory pattern to select the right loader:

  • DefaultModelLoader — standard loading from HuggingFace checkpoints
  • ShardedStateLoader — loads pre-sharded checkpoints (already split per TP rank)
  • DummyModelLoader — creates random weights for testing and benchmarking

The get_model_loader() factory inspects the model config and load format to pick the appropriate loader.
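
The dispatch can be sketched like this (hypothetical signatures, not vLLM's actual classes, which take a full load config and support more formats):

```python
class DefaultModelLoader:
    """Standard loading from HuggingFace checkpoints."""

class ShardedStateLoader:
    """Loads checkpoints already split per TP rank."""

class DummyModelLoader:
    """Random weights, for testing and benchmarking."""

def get_model_loader(load_format="auto"):
    """Sketch of the factory: pick a loader based on the load format."""
    if load_format == "dummy":
        return DummyModelLoader()
    if load_format == "sharded_state":
        return ShardedStateLoader()
    return DefaultModelLoader()
```

The factory pattern keeps loader selection in one place, so model code never needs to know which strategy produced its weights.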

Weight Downloading

Production vLLM handles the full lifecycle of getting weights from HuggingFace Hub to GPU memory. The weight_utils module manages:

  • Downloading model files from HuggingFace Hub (with caching)
  • Handling both safetensors and legacy .bin formats
  • Iterating through multi-shard checkpoints
  • Weight name remapping for architecture variations
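
For multi-shard checkpoints, HuggingFace writes a model.safetensors.index.json whose weight_map records which shard file holds each tensor. A minimal lookup (hypothetical helper, no error handling):

```python
import json

def shard_for_weight(index_path, weight_name):
    """Look up which shard file contains a given tensor.

    The index file's "weight_map" maps tensor names to shard filenames,
    e.g. "model.embed_tokens.weight" -> "model-00001-of-00003.safetensors".
    """
    with open(index_path) as f:
        index = json.load(f)
    return index["weight_map"][weight_name]
```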

Quantized Loading

One of vLLM’s most important production features is loading quantized models — models whose weights have been compressed to use fewer bits per parameter. This dramatically reduces memory usage and can improve throughput.

vllm (production)
vllm/model_executor/layers/quantization/
Quantization methods — FP8, GPTQ, AWQ, BitsAndBytes, and more. Each method defines how to load and dequantize weights.

Each quantization method implements its own weight loading logic:

  • FP8 — weights stored in 8-bit floating point. The weight_loader reads FP8 tensors and optional per-channel scales. At inference time, weights are dequantized to bfloat16/float16 on the fly, or the matmul kernel operates directly in FP8.

  • GPTQ — weights are quantized to 2/3/4/8 bits using a calibration dataset. The loader reads the quantized weight matrix, zero points, and scales, then packs them into the format expected by the GPTQ CUDA kernel.

  • AWQ (Activation-aware Weight Quantization) — similar to GPTQ but optimized for activation distributions. The loader handles AWQ’s specific packing format and scale layout.

  • BitsAndBytes — on-the-fly quantization to 4-bit or 8-bit. Unlike GPTQ/AWQ, the weights are stored in full precision and quantized during loading, making it easy to use with any model.
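
None of these formats fit in a page, but the core idea they share (store low-precision values plus scales, multiply back at load or compute time) can be shown with a toy symmetric int8 scheme. This is illustrative only, not any of the real kernels:

```python
def quantize_int8(weights):
    """Toy symmetric per-tensor int8 quantization.

    Stores integer values in [-127, 127] plus one float scale;
    dequantization multiplies the scale back in.
    """
    scale = max(abs(w) for w in weights) / 127 or 1.0  # avoid scale of 0
    return [round(w / scale) for w in weights], scale

def dequantize_int8(quantized, scale):
    """Recover approximate full-precision weights."""
    return [q * scale for q in quantized]
```

The real methods refine this basic recipe: per-channel or per-group scales, zero points for asymmetric ranges, and bit-packing several sub-byte values per stored integer.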

The quantization method is selected based on the model’s config (the quantization_config field in config.json). Each method provides a custom weight_loader that knows how to handle its specific format.
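
Detecting the method can be sketched as follows (hypothetical helper; quant_method is the field modern HF exporters write, and a production loader would also handle older configs that omit it):

```python
import json

def detect_quant_method(config_path):
    """Read a model's config.json and return its quantization method.

    Returns e.g. "gptq" or "awq", or None for a full-precision model.
    """
    with open(config_path) as f:
        config = json.load(f)
    qcfg = config.get("quantization_config")
    return qcfg.get("quant_method") if qcfg else None
```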

Key Differences from nano-vllm

| Feature | nano-vllm | Production vLLM |
| --- | --- | --- |
| Safetensors iteration | Yes | Yes, plus legacy .bin support |
| Packed modules mapping | Yes | Yes (more architectures) |
| weight_loader protocol | Yes | Yes (same pattern) |
| Quantized loading | No | FP8, GPTQ, AWQ, BitsAndBytes, etc. |
| Distributed loading | Basic TP sharding | Multi-node, pipeline parallel |
| Weight downloading | Assumes local files | HuggingFace Hub integration |
| Pre-sharded checkpoints | No | Yes (ShardedStateLoader) |
| Model architectures | Llama-family | 50+ architectures |

Summary

  • Weight loading follows an inversion-of-control pattern: the loader iterates checkpoint tensors, but each parameter’s weight_loader method knows how to extract its own shard
  • Packed modules mapping handles the mismatch between fused model parameters (like qkv_proj) and separate checkpoint weights (like q_proj, k_proj, v_proj)
  • Safetensors provides safe, memory-mappable tensor storage that supports lazy per-tensor loading
  • Production vLLM extends this with a loader factory, HuggingFace Hub integration, and quantized loading (FP8, GPTQ, AWQ, BitsAndBytes) that can cut memory usage by 2-4x
  • The weight_loader protocol is what makes vLLM’s model support so extensible — adding a new architecture means defining weight_loaders, not modifying the loader itself