Welcome to the final chapter of LLMs from Scratch!
In the previous chapter, Chapter 7: Mixture of Experts (MoE), we learned how to make our model smarter and more scalable by giving it a "team of specialists."
However, we still face a major bottleneck. As the model generates longer and longer text (like writing an essay), it gets slower and slower. The first word is instant, but the 1000th word takes forever.
In this chapter, we will solve this using Inference Optimization with a technique called the KV Cache (Key-Value Cache).
To understand why LLMs get slow, imagine you are writing a speech.
The "Standard" Way (Without Cache):
Every time you want to add one new word, you re-read the entire speech from the beginning. By the time you are at the 100th word, you are re-reading 99 words just to add one more. This is extremely inefficient!
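To put rough numbers on this intuition, here is a small sketch counting how many token passes each approach needs, assuming one forward step per generated word:

```python
# Rough cost comparison for generating n tokens one at a time.
# Without a cache, step t re-processes all t tokens seen so far;
# with a cache, step t processes only the single new token.

def tokens_processed(n, cached):
    if cached:
        return n  # one new token per step
    return sum(t for t in range(1, n + 1))  # 1 + 2 + ... + n = n(n+1)/2

print(tokens_processed(100, cached=False))  # 5050 token passes
print(tokens_processed(100, cached=True))   # 100 token passes
```

The quadratic growth of the uncached loop is exactly why the 1000th word takes so much longer than the first.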
The "Optimized" Way (With Cache):
You keep notes on everything you have already written, so adding a new word requires no re-reading. The KV Cache is exactly that kind of memory. It stops the model from re-calculating the past over and over again.
Recall Chapter 3: Attention Mechanisms (Self & Grouped Query). We learned that every token produces three vectors: a Query ($Q$), a Key ($K$), and a Value ($V$).
When we generate the word "sat" in the sentence "The cat sat", the word "The" doesn't change. Its Key and Value vectors are exactly the same as they were a moment ago.
Instead of recalculating $K$ and $V$ for "The" and "cat" every single step, we save them in a Cache (a temporary storage in GPU memory).
We need to modify our MultiHeadAttention class to hold this memory.
In the __init__ method, we create placeholders (buffers) to store the Keys and Values.
# Inside MultiHeadAttention.__init__
# We register buffers to hold the history.
# They start as None (Empty).
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
Explanation: register_buffer tells PyTorch to store these tensors inside the model, but not to treat them as trainable weights. They are just temporary storage.
Now, when the model runs, we check if we should use the cache.
# Inside MultiHeadAttention.forward
if use_cache:
    # If we have history, glue the new Keys to the old Keys
    if self.cache_k is not None:
        keys = torch.cat([self.cache_k, keys], dim=1)
        values = torch.cat([self.cache_v, values], dim=1)
    # Save the updated history back to the buffer
    self.cache_k = keys
    self.cache_v = values
Explanation: torch.cat (concatenate) is like gluing two lists together. We take the history we saved (self.cache_k) and glue the new token's key (keys) to the end of it.
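As a standalone sketch of what this concatenation does (toy shapes for illustration, not the real model's dimensions):

```python
import torch

# Toy shapes: (batch=1, num_tokens, head_dim=4)
cache_k = torch.randn(1, 2, 4)  # Keys for 2 already-seen tokens
new_k = torch.randn(1, 1, 4)    # Key for the 1 new token

# Glue the new Key onto the history along the token dimension (dim=1)
cache_k = torch.cat([cache_k, new_k], dim=1)
print(cache_k.shape)  # torch.Size([1, 3, 4])
```

Each generation step grows the cache by exactly one token along `dim=1`, while the batch and head dimensions stay fixed.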
In Chapter 4: The GPT Architecture (Transformer Block), we learned about Positional Embeddings. The model needs to know if a word is the 1st or the 100th.
If we only feed the model one word (the newest one), the model might think it's the 1st word. We need to tell it: "Hey, this is actually word #100."
# Inside GPTModel.forward
if use_cache:
    # We track a variable 'current_pos'
    # Start creating IDs from current_pos (e.g., 100)
    pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=device)
    # Update position for next time
    self.current_pos += seq_len
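A tiny standalone demo of the position bookkeeping, using plain variables instead of model attributes:

```python
import torch

current_pos = 100  # 100 tokens already processed
seq_len = 1        # we feed only 1 new token

# The new token gets position ID 100, not 0
pos_ids = torch.arange(current_pos, current_pos + seq_len)
print(pos_ids)          # tensor([100])
current_pos += seq_len  # the next call will start at 101
```

Without this offset, every cached step would hand the newest token the positional embedding of the first word in the sequence.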
This is where the magic happens. We change how we call the model.
In the standard loop, we passed the entire sequence every time. In the cached loop, we pass only the new token.
Old Way (Slow):
# Pass the FULL history every time
logits = model(full_sequence_indices)
next_token = logits[:, -1].argmax()
New Way (Fast):
# Pass ONLY the newest token
# The model looks in its cache for the rest!
logits = model(new_token, use_cache=True)
next_token = logits[:, -1].argmax()
Just like clearing your browser cache, we must clear the KV Cache when we start a brand new prompt.
def reset_kv_cache(self):
    for block in self.trf_blocks:
        block.att.reset_cache()
    self.current_pos = 0
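The reset_cache method called above lives inside the attention class; it is not shown earlier, so here is a minimal sketch of what it might look like, using a stripped-down stand-in class (MiniAttention is illustrative, not the full MultiHeadAttention):

```python
import torch
import torch.nn as nn

class MiniAttention(nn.Module):
    # Minimal stand-in for MultiHeadAttention: just the cache machinery
    def __init__(self):
        super().__init__()
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)

    def reset_cache(self):
        # Forget the stored history so a brand-new prompt starts fresh
        self.cache_k = None
        self.cache_v = None

att = MiniAttention()
att.cache_k = torch.randn(1, 5, 4)  # pretend we cached 5 tokens
att.reset_cache()
print(att.cache_k)  # None
```

Setting the buffers back to None is enough: the next forward pass with use_cache=True will see an empty cache and start accumulating again.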
Let's look at what happens under the hood when we generate the phrase "AI is cool".
We must process the first word normally to fill the cache.
Here is the optimization. We input only "is".
Notice that the model did not have to recalculate the math for "AI". It just retrieved it.
You might ask: "Is there a downside?"
Yes. Memory (VRAM).
Storing the Keys and Values for a long conversation takes up space on your Graphics Card (GPU).
For very long documents (like summarizing a book), the cache can get so big it fills up your GPU memory. This is why techniques like Grouped Query Attention (discussed in Chapter 6: Modern Model Variations (Llama & Qwen)) are important: they shrink the size of the KV Cache.
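To see how quickly this adds up, here is a back-of-envelope size calculation. The numbers below are assumptions for illustration (a roughly GPT-2-medium-sized model in float16), not measurements of any specific model:

```python
# Back-of-envelope KV-cache size (illustrative assumptions)
n_layers = 24
n_heads = 16
head_dim = 64       # hidden size = 16 * 64 = 1024
seq_len = 4096      # tokens held in the cache
bytes_per_val = 2   # float16

# 2x because we store Keys AND Values, per layer, per head, per token
cache_bytes = 2 * n_layers * n_heads * head_dim * seq_len * bytes_per_val
print(f"{cache_bytes / 1024**2:.0f} MiB")  # 384 MiB for ONE sequence
```

Note that this is per sequence: batching several conversations multiplies the cost, which is why cache-shrinking tricks like Grouped Query Attention matter in practice.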
Congratulations! You have reached the end of the tutorial.
In this chapter, we diagnosed why generation slows down as sequences grow, added a KV Cache to the attention layers, tracked token positions across steps with current_pos, and rewrote the generation loop to feed the model only the newest token.
Course Summary: You started with raw text and built a complete Large Language Model from scratch.
You now possess the foundational knowledge to understand papers about GPT-4, Claude, Llama, and whatever comes next. Happy coding!