petite-vllm Part 2: KV Cache & Paged Attention

Building an LLM Serving Engine from Scratch

KV Cache

In part 1 we implemented a simple autoregressive loop and LLM interface. There was no KV caching, which means that at each token generation step, we would recompute the previous $K$ and $V$ projections.

for tok_id in range(max_tokens):
    positions = torch.arange(toks.shape[1]) # get positions based on current_seq_len

    all_logits = self.model.forward(toks, positions)
    last_logits = all_logits[:, -1, :]
    new_tkn = sample(last_logits, temperature, top_k)
    toks = torch.cat([toks, new_tkn.unsqueeze(0)], dim=-1)

To get some intuition on why we need a KV cache, let's take a quick detour into the details of attention and establish some context.

The attention formula:

$$
\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_{\text{head}}}}\right) V
$$

Where $Q$, $K$, and $V$ are the query, key, and value matrices.

Skipping over the details of how this formula is derived, $Q$ represents the current token coming in that we want to predict a next value for, and $K$ and $V$ represent what we know about all tokens we've seen so far (including the current token).

We take the dot product of $Q$ and $K$ to get a similarity score between what we've seen so far (the context) and our current active token. We pipe these scores through a softmax to turn them into a probability distribution, and then multiply by $V$ to produce a weighted combination of value vectors, so tokens with higher relevance contribute more to the output.

This is a highly simplified explanation of what the attention mechanism is doing, but it hopefully gives some intuition as to why $K$ and $V$ are kind of a big deal. They depend on every token we've seen so far and thus scale linearly with sequence length and batch size.
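To make the formula a bit more concrete, here is a minimal single-head sketch in plain Python for one query vector. The function name `attend` and the toy vectors are made up for this illustration; a real implementation runs batched tensor ops rather than list comprehensions.

```python
import math

def attend(q, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = math.sqrt(len(q))                      # sqrt(d_head)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
    peak = max(scores)                             # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]            # softmax over the context
    d_v = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)]
    return weights, out

# Toy context of three tokens; the query lines up with the second key.
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
weights, out = attend([0.0, 4.0], keys, values)
```

The second token gets most of the weight, so `out` ends up dominated by the second value vector. Note that `keys` and `values` hold one row per context token, which is exactly the part that grows as the sequence gets longer.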

So what exactly are $Q$, $K$, and $V$, and how are they produced? They are the outputs of multiplying the input $X$ with the projection weight matrices $W_q$, $W_k$, and $W_v$.

Without a KV cache, these matmuls have the following shapes, where $s_{\text{active}}$ is the number of new tokens coming in and $s_{\text{prior}}$ is the size of the context.

$$
\begin{aligned}
Q &= X \cdot W_q : [B,\; s_{\text{active}},\; H] \times [H,\; n_q,\; d_h] \to [B,\; s_{\text{active}},\; n_q,\; d_h] \\
K &= X \cdot W_k : [B,\; s_{\text{prior}} + s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{prior}} + s_{\text{active}},\; n_{kv},\; d_h] \\
V &= X \cdot W_v : [B,\; s_{\text{prior}} + s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{prior}} + s_{\text{active}},\; n_{kv},\; d_h]
\end{aligned}
$$

$Q$ only depends on $s_{\text{active}}$, but $K$ and $V$ depend on both $s_{\text{active}}$ and $s_{\text{prior}}$. In practice, the Q/K/V projections are typically fused into a single matmul for efficiency. This means that without KV caching, the input $X$ must include all prior tokens, so we end up recomputing $K$ and $V$ for every prior token even though their values haven't changed since the last step, and we compute $Q$ for prior tokens that we immediately discard.

This is why KV caching is such a critical compute optimization. The projection matmul for $K$ and $V$ now only operates over $s_{\text{active}}$ (typically 1 token during decode) instead of the full sequence. For Qwen3-0.6B at sequence length 512, that's a ~512x reduction in the fused QKV projection: from a $[512, 1024] \times [1024, 4096]$ matmul down to $[1, 1024] \times [1024, 4096]$. We simply cache the prior output projections for $K$ and $V$ and update the cache with each forward pass.
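A toy sketch of that idea, using plain Python lists in place of tensors. The weights, hidden states, and the `project` helper are all made up for illustration; the point is only that appending one projected row per step yields the same keys as re-projecting the whole sequence every step.

```python
def project(x, W):
    """Multiply a row vector x [H] by a weight matrix W [H][D]."""
    return [sum(xi * w for xi, w in zip(x, col)) for col in zip(*W)]

W_k = [[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]]        # toy [H=3, D=2] key projection
tokens = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0],
          [2.0, 2.0, 0.0], [1.0, 1.0, 1.0]]         # 4 toy hidden states

# Without a cache: every step re-projects the whole sequence so far.
recompute_calls = 0
for step in range(1, len(tokens) + 1):
    keys_recomputed = [project(x, W_k) for x in tokens[:step]]
    recompute_calls += step

# With a cache: every step projects only the newest token and appends it.
k_cache, cached_calls = [], 0
for x in tokens:
    k_cache.append(project(x, W_k))
    cached_calls += 1

print(keys_recomputed == k_cache)       # True: same keys either way
print(recompute_calls, cached_calls)    # 10 4
```

The same reasoning applies to the value projection. Over a 4-token sequence the uncached loop does 1 + 2 + 3 + 4 = 10 row projections while the cached loop does 4, and the gap widens quadratically as the sequence grows.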

With KV caching our QKV projection becomes:

$$
\begin{aligned}
Q &= X \cdot W_q : [B,\; s_{\text{active}},\; H] \times [H,\; n_q,\; d_h] \to [B,\; s_{\text{active}},\; n_q,\; d_h] \\
K &= X \cdot W_k : [B,\; s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{active}},\; n_{kv},\; d_h] \\
V &= X \cdot W_v : [B,\; s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{active}},\; n_{kv},\; d_h]
\end{aligned}
$$

The new $K$ and $V$ projections are then concatenated with the cached values from prior steps before computing attention, so the full $K$ and $V$ used in the attention formula still span the entire sequence. Our matmul no longer grows with sequence length, but the cache now grows linearly with sequence length. The table and chart below show how this plays out for a few different models.

| Model | $H$ | $n_q$ | $n_{kv}$ | $d_h$ | QKV without cache (seq $N$) | QKV with cache |
| --- | --- | --- | --- | --- | --- | --- |
| Qwen3-0.6B | 1024 | 16 | 8 | 128 | $[N, 1024] \times [1024, 4096]$ | $[1, 1024] \times [1024, 4096]$ |
| Gemma 3 27B | 4608 | 32 | 8 | 128 | $[N, 4608] \times [4608, 6144]$ | $[1, 4608] \times [4608, 6144]$ |
| Llama 3 70B | 8192 | 64 | 8 | 128 | $[N, 8192] \times [8192, 10240]$ | $[1, 8192] \times [8192, 10240]$ |

The output dimension is $(n_q + 2 \cdot n_{kv}) \cdot d_h$, since the Q, K, and V weight matrices are stacked into a single projection.
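As a quick sanity check on those shapes, the fused output width and the matmul cost can be computed directly from the table's values. The helper names here are invented for this sketch, and the FLOP count uses the usual 2 x M x K x N estimate for a dense matmul.

```python
def qkv_out_dim(n_q, n_kv, d_h):
    """Output width of the fused QKV projection: Q heads plus K and V heads."""
    return (n_q + 2 * n_kv) * d_h

def projection_flops(num_tokens, hidden, out_dim):
    """Approximate FLOPs for a [num_tokens, hidden] x [hidden, out_dim] matmul."""
    return 2 * num_tokens * hidden * out_dim

# (H, n_q, n_kv, d_h) exactly as listed in the table
models = {
    "Qwen3-0.6B": (1024, 16, 8, 128),
    "Gemma 3 27B": (4608, 32, 8, 128),
    "Llama 3 70B": (8192, 64, 8, 128),
}
out_dims = {name: qkv_out_dim(n_q, n_kv, d_h)
            for name, (_, n_q, n_kv, d_h) in models.items()}

# Qwen3-0.6B at sequence length 512: full recompute vs. a single decode token
ratio = projection_flops(512, 1024, 4096) // projection_flops(1, 1024, 4096)
```

`out_dims` comes out to 4096, 6144, and 10240 for the three rows, and `ratio` is 512, matching the reduction quoted earlier.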

Implementation Details

Each K and V tensor has shape (num_layers, batch, max_seq_len, num_kv_heads, head_dim). We need two of them (one for K, one for V), so the total memory footprint is:

$$
\text{KV Cache Size} = 2 \times B \times s \times L \times n_{kv} \times d_h \times \text{sizeof(dtype)}
$$

Where $B$ is the batch size, $s$ is the max sequence length, $L$ is the number of layers, $n_{kv}$ is the number of KV heads, and $d_h$ is the head dimension. This comes directly from the shape of the $W_k$ and $W_v$ projection matrices: each produces an output of size $n_{kv} \times d_h$ per token, and we store that for every token, every layer, every sequence in the batch.

For Qwen3-0.6B with $B=1$, $s=1024$, $L=28$, $n_{kv}=8$, $d_h=128$, fp16: $2 \times 1 \times 1024 \times 28 \times 8 \times 128 \times 2 = 117{,}440{,}512$ bytes, or 112 MiB. For Llama 3 70B with 80 layers, the same calculation yields 320 MiB per 1024-token sequence, which grows to 2.5 GiB per sequence at 8192 tokens.
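The formula is easy to turn into a small calculator. This is a sketch with a made-up function name, `kv_cache_bytes`; it assumes a 2-byte dtype (fp16) by default and reports sizes in MiB.

```python
def kv_cache_bytes(batch, seq_len, num_layers, n_kv, d_h, dtype_bytes=2):
    """Bytes for K and V together (the leading 2) in a flat, pre-allocated cache."""
    return 2 * batch * seq_len * num_layers * n_kv * d_h * dtype_bytes

MIB = 1024 ** 2
qwen = kv_cache_bytes(1, 1024, 28, 8, 128)      # 117_440_512 bytes
llama = kv_cache_bytes(1, 1024, 80, 8, 128)     # 335_544_320 bytes
print(qwen / MIB, llama / MIB)                  # 112.0 320.0
```

Dividing the Qwen3-0.6B figure by 1024 tokens gives 114,688 bytes, or 112 KiB of cache per token, which is a handy number to keep in mind for the rest of the post.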

class KVCache:
    def __init__(self, num_layers, num_kv_heads, head_dim,
                 max_seq_len, batch=1, dtype=torch.float32):
        self.k = torch.zeros(
            num_layers, batch, max_seq_len, num_kv_heads, head_dim, dtype=dtype
            )
        self.v = torch.zeros(
            num_layers, batch, max_seq_len, num_kv_heads, head_dim, dtype=dtype
            )
        
        self.num_layers = num_layers
        self.pos_id = 0

    def update(self, layer_idx, k, v):
        next_pos = self.pos_id + k.shape[1]
        self.k[layer_idx, :, self.pos_id : next_pos, :, :] = k
        self.v[layer_idx, :, self.pos_id : next_pos, :, :] = v

        if self.num_layers == layer_idx + 1:
            self.pos_id = next_pos

        return (self.k[layer_idx, :, :next_pos, :, :],
                self.v[layer_idx, :, :next_pos, :, :])

Pre-allocation (lines 4-9): We commit to a `max_seq_len` upfront and pre-allocate the full K and V tensors for all layers. This is a contiguous block of memory: simple and fast, but we pay for the max even if most sequences are shorter.

Cache update (lines 14-17): On each forward pass, we slice the new K and V into the pre-allocated tensors at the current position. This is a simple contiguous write, with no recomputation of prior tokens needed.

Position tracking (lines 19-20): `pos_id` advances after the last layer processes, keeping all layers in sync. Note this only tracks a single sequence. It can't handle batched sequences with different lengths, which is one motivation for paged attention.

Return full cache (lines 22-23): Returns the full K and V up to the current position for attention. The returned slice grows each step, which is the linear memory cost of KV caching.

KV caching trades compute for memory: we no longer recompute $K$ and $V$, but the cache grows linearly with sequence length. For long sequences or large batches, this becomes the dominant memory cost.

Paged Attention

Paged attention addresses memory fragmentation in the KV cache by borrowing from OS paging algorithms, dividing the KV cache into fixed-size pages. It is the hallmark algorithm introduced by vLLM. While paged attention doesn't necessarily reduce the total amount of memory needed for the KV cache, it allows you to use it more efficiently. In practice, flat caches waste a lot of memory because not every sequence uses its full allocation.

Consider a scenario where you have 4 sequences in a batch, and have preallocated 1024 tokens for each sequence's KV cache. Maybe sequence 0 uses all 1024 tokens, sequence 1 uses only 512, sequence 2 uses only about 200, and sequence 3 also uses about 512 tokens.

That's 4096 token slots allocated but only 2248 used: 45% waste. For Qwen3-0.6B in fp16, each token's KV cache across 28 layers costs ~112 KiB, so we've wasted ~202 MiB of the ~448 MiB we allocated. Scale this to a real serving system with hundreds of concurrent sequences and it adds up fast.

This results in both internal fragmentation – memory preallocated for a sequence's KV cache that never gets used, and external fragmentation – gaps between allocations that are too small to be reused.
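The arithmetic behind those waste numbers, as a short sketch. The per-sequence usage of 200 tokens for sequence 2 is an assumption chosen so the total matches the 2248 figure above.

```python
allocated_per_seq = 1024
used = [1024, 512, 200, 512]        # sequence 2 at 200 tokens is an assumed value

# K and V, 28 layers, 8 KV heads, head dim 128, 2 bytes per fp16 element
bytes_per_token = 2 * 28 * 8 * 128 * 2

allocated = allocated_per_seq * len(used)             # 4096 slots
wasted = allocated - sum(used)                        # 1848 slots
waste_fraction = wasted / allocated                   # about 0.45
wasted_mib = wasted * bytes_per_token / 1024 ** 2     # about 202 MiB
allocated_mib = allocated * bytes_per_token / 1024 ** 2   # 448 MiB
```

All of the waste in this toy batch is internal fragmentation: slots reserved for a sequence that never fills them.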

Paged Attention addresses this by divvying up the KV cache into a set of fixed-size blocks. Instead of preallocating a flat cache for all sequences, we allocate a shared pool of blocks. Blocks are assigned to a sequence as needed and when a sequence's current block is filled, the next available block is assigned to it.

This creates a new type of memory access. Previously with flat KV each sequence's cache was contiguous in memory. Now a single sequence's cache is distributed across non-contiguous blocks. This requires multiple scattered reads, which is problematic for GPUs and accelerators that are optimized for large coalesced memory access rather than many smaller ones.
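To see how much slack block-granular allocation leaves in the four-sequence scenario above, we can count the blocks each sequence would need. This is a rough sketch: the block size of 16 tokens and the 200-token length for sequence 2 are both assumptions for illustration.

```python
import math

block_size = 16                     # assumed block size, in tokens
used = [1024, 512, 200, 512]        # same toy batch; 200 is an assumed value

blocks_per_seq = [math.ceil(n / block_size) for n in used]   # [64, 32, 13, 32]
paged_slots = sum(blocks_per_seq) * block_size               # 2256
slack = paged_slots - sum(used)                              # 8 unused slots
```

With paging, each sequence can waste at most `block_size - 1` slots (the tail of its last block), so the 1848 wasted slots of the flat layout shrink to 8 here. The cost is the scattered access pattern described above.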

Implementation Details

To handle this new blocked KV cache, we need to introduce a few new classes and concepts to our serving system:

  1. Block - data class to represent the block data structure. It holds the block id and block size
  2. Block Pool - maintains the pool of available blocks for the cache manager to use.
  3. KV cache manager - Interface between the model and the block KV cache
  4. BlockKVCache - Our actual implementation of the cache, holds the cache tensors.
@dataclass
class Block:
    size: int
    id: int

class BlockPool:
    def __init__(self, num_blocks: int, block_size: int):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks_queue = [
            Block(self.block_size, i) for i in range(self.num_blocks)
        ]

    def allocate(self) -> Block:
        block = self.free_blocks_queue.pop()
        return block

    def free(self, block: Block):
        self.free_blocks_queue.append(block)

Block (lines 1-4): Minimal dataclass with just a block ID and size. Each block represents a fixed-size chunk of the KV cache tensor that can be assigned to any sequence.

Pool initialization (lines 6-12): Pre-creates all blocks upfront as a list. This is our free list: blocks are popped off when allocated and pushed back when freed. Compare to vLLM's doubly-linked list with LRU eviction.

allocate() (lines 14-16): LIFO allocation via pop() from the end. O(1). Returns a Block to the caller (the KVCacheManager).

free() (lines 18-19): Returns a block to the pool. Called when a sequence completes and its blocks can be reused by other sequences.

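The LIFO behaviour of a list-backed free list is easy to see with bare block ids. This sketch uses plain integers instead of `Block` objects, and `allocate_or_none` is a hypothetical helper (not part of the code above) showing one way to report an exhausted pool instead of letting `pop()` raise `IndexError`.

```python
free_blocks = list(range(4))        # block ids 0..3, all free

a = free_blocks.pop()               # 3: the last id in the list goes out first
b = free_blocks.pop()               # 2
free_blocks.append(a)               # the sequence holding block 3 finishes
c = free_blocks.pop()               # 3 again: the most recently freed block is reused first

def allocate_or_none(pool):
    """Hypothetical guard: return None when the pool is empty."""
    return pool.pop() if pool else None

remaining = [allocate_or_none(free_blocks) for _ in range(3)]   # [1, 0, None]
```

What to do when the pool runs dry (wait, preempt a sequence, or evict cached blocks) is a scheduling decision that sits outside the pool itself.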
class BlockMapping(NamedTuple):
    block_ids: torch.Tensor
    offsets: torch.Tensor

class BlockTableEntry(NamedTuple):
    blocks: list[Block]
    num_tokens: int

class KVCacheManager:
    def __init__(self, kv_config, model_config):
        self.block_size = kv_config.block_size
        self.kv_cache = BlockKVCache(
            model_config.num_hidden_layers,
            model_config.num_key_value_heads,
            model_config.head_dim,
            kv_config.num_blocks, kv_config.block_size,
        )
        self.block_pool = BlockPool(
            kv_config.num_blocks, kv_config.block_size)
        self.block_table: dict[int, BlockTableEntry] = {}

    def register_sequence(self, seq_id):
        block = self.block_pool.allocate()
        self.block_table[seq_id] = BlockTableEntry([block], 0)

    def update_block_tables(self, seq_id, num_new_tokens):
        blocks, num_tokens = self.block_table[seq_id]
        block_ids = []
        offsets = []
        for i in range(num_new_tokens):
            pos = num_tokens + i
            offset = pos % self.block_size
            if offset == 0 and pos > 0:
                blocks.append(self.block_pool.allocate())
            block_ids.append(blocks[pos // self.block_size].id)
            offsets.append(offset)
        self.mapping = BlockMapping(
            torch.tensor(block_ids), torch.tensor(offsets))
        self.block_table[seq_id] = BlockTableEntry(
            blocks, num_tokens + num_new_tokens)

    def update(self, layer_idx, seq_id, k, v):
        self.kv_cache.update(layer_idx, self.mapping, k, v)
        full_k, full_v = self.read(layer_idx, seq_id)
        return full_k, full_v

    def read(self, layer_idx, seq_id):
        blocks, num_tokens = self.block_table[seq_id]
        block_ids = torch.tensor([b.id for b in blocks])
        return self.kv_cache.read(layer_idx, block_ids, num_tokens)

Data structures (lines 1-7): `BlockMapping` tells `BlockKVCache` where to scatter/gather when reading or writing to the cache. `BlockTableEntry` tracks which blocks belong to a sequence and how many tokens it has.

Orchestrator init (lines 9-20): KVCacheManager owns both the BlockPool and BlockKVCache, plus the block_table dict mapping each sequence to its blocks. This is the single entry point the model uses for all cache operations.

Register sequence (lines 22-24): Called once when a new sequence enters the batch. Allocates its first block and creates a block table entry with 0 tokens.

Build block mapping (lines 26-40): The core bookkeeping, called before each forward pass. For each new token, computes which block it belongs to and its offset within that block. When a block fills up (offset == 0), a new block is allocated from the pool. Produces the BlockMapping used by scatter/gather.

Update & read (lines 42-50): Called per-layer during the forward pass. `update` writes new K,V into the cache and reads back the full sequence. `read` gathers block IDs from the table and delegates to BlockKVCache.

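The position arithmetic at the heart of the block mapping can be checked in isolation. This is a simplified sketch with an invented function name; unlike `update_block_tables` it does no allocation and assumes the sequence already owns enough blocks.

```python
def slots_for_new_tokens(block_ids, num_tokens, num_new_tokens, block_size):
    """Map logical token positions to (physical block id, offset) pairs.

    Assumes block_ids already holds enough blocks for the new tokens.
    """
    slots = []
    for pos in range(num_tokens, num_tokens + num_new_tokens):
        slots.append((block_ids[pos // block_size], pos % block_size))
    return slots

# A sequence that owns physical blocks 7 and 2 (in that order), block_size 4.
prefill = slots_for_new_tokens([7, 2], num_tokens=0, num_new_tokens=6, block_size=4)
# [(7, 0), (7, 1), (7, 2), (7, 3), (2, 0), (2, 1)]
decode = slots_for_new_tokens([7, 2], num_tokens=6, num_new_tokens=1, block_size=4)
# [(2, 2)]
```

The logical position only ever indexes into the sequence's own block list, so the physical block ids can be in any order and anywhere in the pool.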
class BlockKVCache:
    def __init__(self, num_layers, num_kv_heads, head_dim,
                 num_blocks, block_size, dtype=torch.float32):
        self.k = torch.zeros(num_layers, num_blocks, block_size,
                             num_kv_heads, head_dim, dtype=dtype)
        self.v = torch.zeros(num_layers, num_blocks, block_size,
                             num_kv_heads, head_dim, dtype=dtype)
        self.num_layers = num_layers

    def update(self, layer_idx, block_mapping, k, v):
        self.k[layer_idx].index_put_(
            (block_mapping[0], block_mapping[1]), k)
        self.v[layer_idx].index_put_(
            (block_mapping[0], block_mapping[1]), v)

    def read(self, layer_idx, block_mapping, seq_len):
        full_k = self.k[layer_idx, block_mapping].reshape(
            -1, self.k.shape[-2], self.k.shape[-1])[:seq_len]
        full_v = self.v[layer_idx, block_mapping].reshape(
            -1, self.v.shape[-2], self.v.shape[-1])[:seq_len]
        return full_k, full_v

Block-shaped tensors (lines 1-8): Unlike the flat KVCache, which is shaped `(layers, batch, seq_len, ...)`, this is shaped `(layers, num_blocks, block_size, heads, dim)`. The seq_len dimension is replaced by a 2D block grid: tokens live at [block_id][offset].

Scatter write (lines 10-14): index_put_ scatters new K,V into arbitrary (block_id, offset) positions. This is the key operation: tokens from different sequences write to different physical blocks.

Gather read (lines 16-21): Gathers K,V from a sequence's scattered blocks back into a contiguous tensor for attention.

1. Index by block_ids, reshape from (n_blocks, block_size, ...) to (n_blocks*block_size, ...).
2. Slice to the actual seq_len, since the last block may be partially filled.
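
The scatter and gather steps can be mimicked with nested Python lists standing in for the block-shaped tensor. This is only an illustration of the indexing pattern, with made-up helper names and string placeholders for the per-token K/V rows; it is not how the tensor version is executed.

```python
block_size = 2
pool = [[None] * block_size for _ in range(4)]     # 4 physical blocks of 2 slots

def scatter(pool, slots, tokens):
    """Write each token to its (block_id, offset) slot."""
    for (block_id, offset), tok in zip(slots, tokens):
        pool[block_id][offset] = tok

def gather(pool, block_ids, seq_len):
    """Concatenate a sequence's blocks in table order, then trim the unused tail."""
    flat = [tok for block_id in block_ids for tok in pool[block_id]]
    return flat[:seq_len]

# One sequence of 3 tokens stored in physical blocks 3 then 1.
scatter(pool, [(3, 0), (3, 1), (1, 0)], ["t0", "t1", "t2"])
print(gather(pool, [3, 1], seq_len=3))     # ['t0', 't1', 't2']
```

The trailing slice matters: block 1 still has an empty second slot, and without trimming to `seq_len` it would leak into the attention input.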

vLLM Internals

Our implementation covers the core concepts of paged attention, but it's intentionally simplified. vLLM's actual implementation adds several layers of complexity for performance, hybrid attention support, and features like prefix caching. Here's a mapping to help you better orient yourself inside the vLLM code base:

| petite-vllm | vLLM v1 | Notes |
| --- | --- | --- |
| BlockPool | BlockPool | We use a simple list as a stack (LIFO). vLLM uses a doubly-linked list with reference counting and LRU eviction to support prefix caching: blocks that share a common prompt prefix can be reused across sequences. |
| BlockKVCache (flat tensor) | gpu_model_runner.py::kv_caches | Same idea: pre-allocated block-shaped tensors on GPU. vLLM creates them during model runner init rather than in a standalone class. |
| BlockKVCache.update() (index_put_) | Triton kernel reshape_and_cache_flash() | Same scattered write, but vLLM fuses it into a Triton kernel for GPU throughput. Our index_put_ is the PyTorch equivalent: correct but slower. |
| KVCacheManager | KVCacheCoordinator, SingleTypeKVCacheManager, BlockTable | We have one class doing everything. vLLM splits this into 3 layers: coordinator (top-level scheduling), per-attention-type manager, and block table (slot mapping). Needed for multi-backend support (e.g. different KV cache types per layer). |
| BlockMapping | BlockTable.slot_mapping | We build block mappings in Python. vLLM computes slot mappings in a Triton kernel, which avoids CPU-GPU sync overhead at scale. |
| update_block_tables() | KVCacheCoordinator.allocate_new_blocks(), BlockTable.compute_slot_mapping() | Same logical flow: allocate blocks, then compute mappings. vLLM does this in the scheduler phase before the model forward pass, same as us. |

What Next?

In this post, we took a deeper look at how KV caching is implemented. We started from recomputing the K and V tensors from scratch, took a brief tour through flat caching, and arrived at our current destination: paged attention.

There's still much to explore with KV caching, with techniques like prefix caching, MLA and hybrid attention schemes, but these are out of scope for petite-vllm for now.

In the next post, we'll dive into continuous batching, where we will allow sequences to enter and exit the batch independently.

Resources

The code for this post is available in the petite-vllm repo here:

I found the following resources helpful while working on this implementation and writeup.