
petite-vllm Part 1: Autoregressive Generation

Building an LLM Serving Engine from Scratch

Part 1 starts simple: take a pretrained Qwen3 0.6B model, implement its architecture from scratch in PyTorch, load the weights, and generate text one token at a time. The result is very slow, just 1.3 tokens/sec on CPU, and serves as the baseline that we'll be optimizing soon when we add improvements such as KV caching and PagedAttention. All code is on the GitHub branch lesson/01-generation.

In this post I'll give a quick rundown of each component of the transformer architecture that we implemented. I don't provide any in-depth explanations of the components here, and assume the reader has some basic familiarity with MLP, Attention, etc.

The Transformer Block

Qwen3 is a decoder-only transformer, like most modern LLMs, built from a stack of 28 identical layers. There's an embedding layer at the beginning which takes the input token ids and returns the corresponding vectors from the model's embedding table. At the end is the LM head, a final linear layer that projects from the hidden dimension back to the vocabulary, outputting the logits (the model's raw, unnormalized scores) over the vocabulary at each position.

Here's the full model:

class Qwen3ForCausalLM(nn.Module):
    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.config = config
        self.embed_tokens = VocabEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = LMHead(
            config.hidden_size,
            config.vocab_size,
            weight=self.embed_tokens.weight if config.tie_word_embeddings else None,
        )

    def forward(self, input_ids, positions):
        x = self.embed_tokens(input_ids)  # go from vocab_size -> hidden_size
        for layer in self.layers:
            x = layer(x, positions)
        x = self.norm(x)
        return self.lm_head(x)  # go back from hidden_size -> vocab_size

Qwen3-0.6B has 28 of these layers with a hidden size of 1024. Let's walk through each component.

RMSNorm

RMSNorm is the normalization method of choice for most modern transformers. Just normalize by the root-mean-square and scale by a learnable weight:

def forward(self, x):
    return x * self.weight / (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()

The mean(-1, keepdim=True) normalizes per-token along the last dimension (the hidden dimension), and keepdim=True preserves the dimension for broadcasting.
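To make the one-liner concrete, here is a minimal self-contained sketch of the module with a quick sanity check. The constructor signature and the default eps are my assumptions, inferred from how the model calls RMSNorm(config.hidden_size, config.rms_norm_eps); the check confirms that every token vector comes out with an RMS of roughly 1 while the weight is still all ones.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """RMSNorm sketch; the constructor signature is an assumption."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable scale
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the hidden dimension: one value per token.
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * self.weight / (mean_sq + self.eps).sqrt()


torch.manual_seed(0)
norm = RMSNorm(hidden_size=1024)
x = torch.randn(2, 5, 1024) * 7.0   # (batch, seq_len, hidden_size), arbitrary scale
y = norm(x)
rms = y.pow(2).mean(-1).sqrt()      # per-token RMS after normalization
```

Whatever scale the input has, the per-token RMS lands close to 1 before the learned weight rescales it.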

Attention with GQA and RoPE

The attention layer is where most of the interesting design decisions live. Qwen3 uses three features worth understanding: grouped-query attention, QK normalization, and rotary position embeddings.

Grouped-query attention

Standard multi-head attention gives each head its own Q, K, and V projections. Grouped-query attention (GQA) uses fewer KV heads than Q heads. Qwen3 0.6B has 16 Q heads but only 8 KV heads, so each KV head is shared by 16 / 8 = 2 Q heads. That gives 8 groups of 2 Q heads each.

This saves memory and compute on the KV projections (important for the KV cache in part 2) while maintaining the expressiveness of having many Q heads. In practice, we project K and V with the smaller head count, then repeat_interleave them to match the Q head count before computing attention. Note that the repeat factor has to be the number of Q heads per KV head (2 here), not the number of groups (8), so that is the value self.num_kv_groups must hold in the snippet below:

k = torch.repeat_interleave(k, self.num_kv_groups, dim=2)
v = torch.repeat_interleave(v, self.num_kv_groups, dim=2)
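A quick shape check makes the head sharing easier to see. This is only a sketch: the (batch, seq_len, heads, head_dim) layout is my assumption based on the dim=2 argument above, and the variable names are illustrative.

```python
import torch

batch, seq_len, head_dim = 1, 4, 128
num_q_heads, num_kv_heads = 16, 8
repeats = num_q_heads // num_kv_heads   # 2 Q heads share each KV head

# Assumed layout: (batch, seq_len, heads, head_dim), so heads live on dim=2.
k = torch.randn(batch, seq_len, num_kv_heads, head_dim)
k_expanded = torch.repeat_interleave(k, repeats, dim=2)

# Each KV head is repeated back to back: [k0, k0, k1, k1, ...],
# so Q head h ends up attending with KV head h // repeats.
```

After the expansion, K has 16 heads and can be matched one-to-one against the 16 Q heads.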

QK normalization

Qwen3 applies RMSNorm to the Q and K vectors per-head, after projection but before position encoding. The norm operates on head_dim (128), not hidden_size (1024), since each head is normalized independently. This stabilizes attention scores by preventing Q and K magnitudes from growing too large in deeper layers.
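Here is a rough sketch of what "per-head" means in terms of shapes. The rms_norm helper and the variable names are hypothetical, and I'm assuming a single weight vector of size head_dim is shared across heads. The key point is that the projection output is split into heads before normalizing, so the mean is taken over 128 values rather than the full projected width.

```python
import torch


def rms_norm(x, weight, eps=1e-6):
    # Hypothetical helper: normalizes over the last dimension only.
    return x * weight / (x.pow(2).mean(-1, keepdim=True) + eps).sqrt()


torch.manual_seed(0)
batch, seq_len, num_heads, head_dim = 1, 3, 16, 128

# Stand-in for the output of the Q projection: 16 heads * 128 dims = 2048 wide.
q_proj_out = torch.randn(batch, seq_len, num_heads * head_dim) * 5.0

# Split into heads first, so the last dimension is head_dim.
q = q_proj_out.view(batch, seq_len, num_heads, head_dim)
q_norm_weight = torch.ones(head_dim)   # assumed: one weight vector shared by all heads
q = rms_norm(q, q_norm_weight)

per_head_rms = q.pow(2).mean(-1).sqrt()   # (batch, seq_len, num_heads)
```

Every (token, head) pair gets its own normalization, which is why the norm weight has size head_dim rather than hidden_size.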

Rotary position embeddings (RoPE)

The original transformer adds position information by summing a position embedding with the token embedding. RoPE encodes position by rotating the Q and K vectors by an angle proportional to position. The attention dot product then naturally depends on the relative distance between tokens, not their absolute positions.

The implementation precomputes an inverse frequency vector:

inv_freq[i] = 1 / (base ^ (2i / head_dim))
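As a minimal sketch of that precomputation, assuming i runs from 0 to head_dim/2 - 1 and using the half-split layout: rope_tables is a hypothetical helper name, and the default base of 1,000,000 is the Qwen3 rope_theta value.

```python
import torch


def rope_tables(head_dim, max_positions, base=1_000_000.0):
    """Hypothetical helper: returns inv_freq plus cos/sin lookup tables."""
    # One frequency per rotated pair: i = 0 .. head_dim/2 - 1.
    i = torch.arange(0, head_dim // 2, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (2 * i / head_dim))

    positions = torch.arange(max_positions, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)        # (max_positions, head_dim/2)

    # Half-split layout: dimensions j and j + head_dim/2 share the same angle.
    angles = torch.cat((angles, angles), dim=-1)     # (max_positions, head_dim)
    return inv_freq, angles.cos(), angles.sin()


inv_freq, cos, sin = rope_tables(head_dim=128, max_positions=16)
```

Position 0 gets an angle of zero everywhere, so its cos row is all ones and its sin row is all zeros, meaning the first token is left unrotated.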

For each position, we compute cos and sin of the angles and apply the rotation. An important implementation detail: there are two valid ways to pair dimensions for rotation. The "even/odd" approach pairs dimensions (0,1), (2,3), etc. The "half-split" approach (which HF and Qwen3 use) pairs dimension i with i + head_dim/2. These produce different results, and you must match what the model was trained with. I initially implemented even/odd and got degenerate repetition loops until I switched to half-split:

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_emb(q, k, cos, sin):
    cos = cos.unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, head_dim)
    sin = sin.unsqueeze(0).unsqueeze(2)

    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)

    return q_rotated, k_rotated
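The relative-position claim is easy to check numerically. The sketch below rotates a single head vector at a given position (rotate_at is a hypothetical helper, not part of the repo) and compares attention scores for pairs with the same offset at different absolute positions.

```python
import torch


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotate_at(x, position, head_dim=128, base=1_000_000.0):
    """Hypothetical helper: apply half-split RoPE to one vector at one position."""
    i = torch.arange(0, head_dim // 2, dtype=torch.float64)
    inv_freq = 1.0 / (base ** (2 * i / head_dim))
    angles = position * inv_freq
    angles = torch.cat((angles, angles))
    return x * angles.cos() + rotate_half(x) * angles.sin()


torch.manual_seed(0)
q = torch.randn(128, dtype=torch.float64)
k = torch.randn(128, dtype=torch.float64)

score_a = torch.dot(rotate_at(q, 5), rotate_at(k, 2))       # offset 3
score_b = torch.dot(rotate_at(q, 105), rotate_at(k, 102))   # offset 3, shifted by 100
score_c = torch.dot(rotate_at(q, 5), rotate_at(k, 4))       # offset 1
```

score_a and score_b agree because both pairs are 3 positions apart, while score_c differs because the offset changed.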

The MLP: SwiGLU

The MLP in each transformer layer uses SwiGLU, a gated activation that's become standard in modern LLMs (Llama, Mistral, and Qwen all use it, and Gemma uses the closely related GeGLU). A single linear layer projects to 2x the intermediate size. The output is split in half: one half passes through SiLU (the "gate") and is then multiplied element-wise with the other half (the "value"):

class SiluAndMul(nn.Module):
    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

We fuse the gate and up projections into one gate_up_proj layer because one large matrix multiply is generally faster than two smaller ones. However, this means our weight naming differs from HuggingFace's, which keeps them split, so we need to handle that during weight loading.
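Putting the pieces together, a full MLP block might look roughly like this. The MLP class name, the bias=False choice, and the down_proj attribute are my assumptions (they mirror the HF layout), and the intermediate size of 3072 is taken from the Qwen3-0.6B config.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiluAndMul(nn.Module):
    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value


class MLP(nn.Module):
    """Hypothetical wrapper around the fused projection."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # One matmul produces both the gate and the value halves.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.act = SiluAndMul()
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(self.act(self.gate_up_proj(x)))


torch.manual_seed(0)
mlp = MLP(hidden_size=1024, intermediate_size=3072)
x = torch.randn(2, 5, 1024)   # (batch, seq_len, hidden_size)
out = mlp(x)
```

The fused layer is 6144 wide, SiluAndMul halves it back to 3072, and down_proj returns to the hidden size.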

Weight tying

The embedding table is a matrix of shape (vocab_size, hidden_size), where each row is the learned vector for one token. The LM head needs a matrix of the same shape to project hidden states back to vocabulary scores. Rather than storing two separate matrices of roughly 152K x 1024 (Qwen3's vocab_size is 151,936), Qwen3 0.6B shares the same weight for both. This is called weight tying.

The LM head calls F.linear(x, self.weight) directly instead of wrapping an nn.Linear, which would create its own separate parameter. This lets both modules share the same tensor in memory.
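A stripped-down sketch of the tying, with the class bodies guessed from how the model constructs them (the real VocabEmbedding and LMHead in the repo may differ), and tiny sizes so it runs instantly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VocabEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size) * 0.02)

    def forward(self, input_ids):
        return F.embedding(input_ids, self.weight)


class LMHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, weight=None):
        super().__init__()
        # Reuse the embedding's Parameter when one is passed in; otherwise allocate our own.
        if weight is None:
            weight = nn.Parameter(torch.randn(vocab_size, hidden_size) * 0.02)
        self.weight = weight

    def forward(self, x):
        # F.linear computes x @ weight.T, so a (vocab_size, hidden_size) weight works as-is.
        return F.linear(x, self.weight)


embed = VocabEmbedding(vocab_size=1000, hidden_size=64)   # small demo sizes
head = LMHead(64, 1000, weight=embed.weight)
logits = head(embed(torch.tensor([[1, 2, 3]])))
```

Because head.weight and embed.weight are the same Parameter object, an update to one is an update to the other and no second copy is stored.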

Weight loading

As alluded to above, our model's parameter names don't exactly match HuggingFace's, so loading requires the following remappings:

  1. Strip the model. prefix: HF wraps the transformer layers under a model attribute, so keys look like model.layers.0.self_attn.q_proj.weight. Ours are just layers.0.self_attn.q_proj.weight.

  2. Fuse gate and up projections: HF stores gate_proj and up_proj as separate weights. We concatenate them into gate_up_proj: torch.cat([gate_proj, up_proj], dim=0).

  3. Skip tied weights: lm_head.weight is tied to embed_tokens.weight, so we don't load it separately. The inv_freq buffers in our RoPE are precomputed from config, not learned, so they have no counterpart in the checkpoint.

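remapped = {}
for key, value in hf_state.items():
    if key == "lm_head.weight":
        continue  # tied to embed_tokens.weight, nothing to load
    new_key = key.removeprefix("model.")
    if "gate_proj.weight" in new_key:
        # Fuse gate and up. Gate rows go first so torch.chunk recovers (gate, value).
        up_key = key.replace("gate_proj", "up_proj")
        remapped[new_key.replace("gate_proj", "gate_up_proj")] = torch.cat(
            [value, hf_state[up_key]], dim=0
        )
        continue
    if "up_proj.weight" in new_key:
        continue  # already consumed by the fusion above
    remapped[new_key] = value

# strict=False tolerates keys that exist on only one side of the mapping
self.model.load_state_dict(remapped, strict=False)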
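Since strict=False can hide a bad remapping, it is worth convincing yourself that the fusion is exact. The sketch below uses small random stand-ins for the HF tensors (the sizes and names are illustrative only) and compares the fused path against the separate gate and up projections.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, intermediate_size = 64, 192   # small stand-in sizes

# Stand-ins for the HF tensors; linear weights are (out_features, in_features).
gate_w = torch.randn(intermediate_size, hidden_size)
up_w = torch.randn(intermediate_size, hidden_size)

# Same concatenation as the remapping loop: stack along the output dimension, gate first.
gate_up_w = torch.cat([gate_w, up_w], dim=0)   # (2 * intermediate_size, hidden_size)

x = torch.randn(3, hidden_size)

# Fused path: one matmul, then split the result in half.
gate, value = torch.chunk(F.linear(x, gate_up_w), 2, dim=-1)
fused_out = F.silu(gate) * value

# Reference path: two separate matmuls, as in the HF layout.
separate_out = F.silu(F.linear(x, gate_w)) * F.linear(x, up_w)
```

The order inside torch.cat matters: swapping gate and up would still load without errors but would apply SiLU to the wrong half.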

The generation loop

Without a KV cache, autoregressive generation is extremely wasteful. For each new token, we re-run the entire sequence through the model, recomputing the attention scores for tokens which have already been seen and computed:

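for _ in range(max_tokens):
    # Positions 0..seq_len-1 for the whole sequence so far
    positions = torch.arange(toks.shape[1], device=toks.device)

    # No cache: every token seen so far goes through the model again
    all_logits = self.model(toks, positions)
    last_logits = all_logits[:, -1, :]  # only the last position predicts the next token
    new_tkn = sample(last_logits, temperature, top_k)  # shape (batch,)
    # unsqueeze(-1) gives (batch, 1), so the append also works for batch sizes above 1
    toks = torch.cat([toks, new_tkn.unsqueeze(-1)], dim=-1)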

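Because the sequence fed to the model grows by one each step, step t pushes t token positions through the model, and generating n tokens processes O(n^2) positions in total. Attention over a length-t sequence is itself O(t^2), so the attention cost grows even faster, as O(n^3) overall. The KV cache (covered in part 2) fixes this by caching the key and value tensors from earlier positions so we only process the new token each step.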
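To put rough numbers on the waste, here is a small counting sketch. positions_processed is a hypothetical helper that only counts how many token positions go through the model; it ignores the per-position cost of attention.

```python
def positions_processed(prompt_len, new_tokens, use_cache):
    """Count token positions pushed through the model over a whole generation."""
    total = 0
    seq_len = prompt_len
    for step in range(new_tokens):
        if use_cache and step > 0:
            total += 1          # only the newest token is processed
        else:
            total += seq_len    # the whole sequence so far is processed
        seq_len += 1
    return total


no_cache = positions_processed(prompt_len=8, new_tokens=100, use_cache=False)
with_cache = positions_processed(prompt_len=8, new_tokens=100, use_cache=True)
print(no_cache, with_cache)  # 5750 107
```

For an 8-token prompt and 100 generated tokens, the uncached loop processes 5750 positions versus 107 with a cache, and the gap keeps widening as the output gets longer.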

Sampling

The sampler converts raw logits (unnormalized scores over the vocabulary) into a token selection. With temperature=0, it's pure greedy — just take the argmax. With temperature > 0:

  1. Divide logits by temperature (higher temperature = flatter distribution = more randomness)
  2. Keep only the top-k highest scores, mask everything else to -inf
  3. Softmax to get probabilities
  4. Sample from the distribution with torch.multinomial

def sample(logits, temperature=1.0, top_k=50):
    if temperature == 0.0:
        return logits.argmax(dim=-1)
    logits = logits / temperature
    vals, _ = torch.topk(logits, top_k, dim=-1)
    smallest = vals[:, -1:]
    masked = torch.where(logits >= smallest, logits, -float("inf"))
    probs = torch.softmax(masked, -1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
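A toy run shows both modes. The sample function is repeated so the snippet runs on its own, and the 5-token vocabulary is made up for illustration.

```python
import torch


def sample(logits, temperature=1.0, top_k=50):
    if temperature == 0.0:
        return logits.argmax(dim=-1)
    logits = logits / temperature
    vals, _ = torch.topk(logits, top_k, dim=-1)
    smallest = vals[:, -1:]   # k-th largest logit per row
    masked = torch.where(logits >= smallest, logits, -float("inf"))
    probs = torch.softmax(masked, -1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])   # toy 5-token vocabulary

greedy = sample(logits, temperature=0.0)
draws = torch.stack([sample(logits, temperature=1.0, top_k=2) for _ in range(200)])
seen = set(draws.flatten().tolist())
```

Greedy always picks token 0, and with top_k=2 the sampled tokens never fall outside {0, 1} because everything else is masked to -inf before the softmax. One caveat: torch.topk raises an error if top_k exceeds the vocabulary size, so a real implementation may want to clamp it.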

Result

petite-vllm % python examples/01_basic_generation.py
Loading weights: 100%|████████████████████████████████████| 311/311 [00:00<00:00, 16933.37it/s]
1.3 tokens/sec
The capital of France is:  Paris. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of China

It works! The model generates coherent text, but 1.3 tokens/sec is extremely slow. We'll speed this up in part 2 by implementing the KV Cache: cache the K and V tensors from previous positions so we don't recompute them.


Debugging notes

Two bugs that produced identical symptoms (degenerate repetition loops) but had different root causes:

  1. Wrong RoPE base frequency: rope_theta should be 1,000,000 for Qwen3, but I defaulted to 10,000. With the wrong base, rotation frequencies are up to roughly 100x too fast (the error grows from none in the first dimension pair to nearly 100x in the last), and positions that should look similar to the model look wildly different. Attention scores become meaningless.

  2. Wrong RoPE dimension pairing: I initially used even/odd interleaving (pairing dimensions 0,1 then 2,3) instead of half-split (pairing 0,64 then 1,65). Both are mathematically valid RoPE implementations, but the model was trained with half-split.
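For the first bug, the size of the error can be read straight off the inverse frequency formula. This is a quick sketch in plain Python, with head_dim taken from the Qwen3-0.6B value of 128 and inv_freq as an illustrative helper:

```python
head_dim = 128


def inv_freq(base, i):
    # Same formula as the RoPE precomputation: 1 / base^(2i / head_dim)
    return 1.0 / (base ** (2 * i / head_dim))


# How much faster each dimension pair rotates with base 10,000 instead of 1,000,000.
ratios = [
    inv_freq(10_000.0, i) / inv_freq(1_000_000.0, i)
    for i in range(head_dim // 2)
]
```

The ratio is 100^(2i / head_dim): exactly 1 for the first pair and about 93 for the last, so the slowest-rotating dimensions, which carry the long-range position signal, are the ones hit hardest.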

One more note about debugging small models: check your prompt! A typo in my initial prompt ("captial" instead of "capital") caused the 0.6B model to fall into a repetition loop. Larger models handle noisy input gracefully, but small models don't have the capacity to recover.