In the previous chapter, Mixture of Experts (MoE), we learned how to make models smarter and faster by creating "specialists" inside the brain.
But even the smartest expert has a limit: Memory. If you ask a standard GPT model, "Who won the game yesterday?", it cannot answer because its training data is months or years old.
To fix this, we don't need a bigger brain; we need an Open Book. We need RETRO (Retrieval-Enhanced Transformer).
Imagine you are taking a very hard history exam. A standard GPT is a closed-book student: it can only answer from what it memorized during training. RETRO is an open-book student: it is allowed to look things up in a textbook (an external database) while writing its answer.
Use Case: Factual Question Answering or updating knowledge without retraining.
RETRO is a modification of the standard GPT (Decoder-only) architecture. It adds an external memory system.
pretrain_retro.py

Training a RETRO model requires a database index and a specific script.
The command looks similar to GPT, but you must point it to your retrieval data.
python pretrain_retro.py \
--num-layers 12 \
--hidden-size 768 \
--retro-work-dir /path/to/retrieval_data \
--retro-add-retriever \
--seq-length 1024 \
--data-path my_text_data
Key Arguments:
--retro-work-dir: The folder containing your "external brain" (the pre-processed database of text chunks).
--retro-add-retriever: Tells the model, "Don't just guess! Look up information first."

How does the model merge its own thoughts with the external notes? It uses a Chunked Cross-Attention mechanism.
The input sequence is split into small chunks. For each chunk, the model retrieves "Neighbors" (relevant external text).
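To make the chunk-and-retrieve step concrete, here is a toy sketch. It is not Megatron's real retriever: the chunk size of 64 matches the RETRO paper, but `retrieve_neighbors` is a stand-in that returns placeholder strings instead of querying a database.

```python
# Toy sketch of chunk-wise retrieval (NOT Megatron's real retriever).
# Assumption: chunk_size=64 tokens, as in the RETRO paper; the database
# lookup is faked with a placeholder function.

def split_into_chunks(token_ids, chunk_size=64):
    """Split the input sequence into fixed-size chunks."""
    return [token_ids[i:i + chunk_size]
            for i in range(0, len(token_ids), chunk_size)]

def retrieve_neighbors(chunk, k=2):
    """Stand-in for the database lookup: return k 'neighbor' texts."""
    return [f"neighbor_{i}_for_chunk_starting_at_{chunk[0]}" for i in range(k)]

token_ids = list(range(256))           # a fake 256-token input
chunks = split_into_chunks(token_ids)  # 4 chunks of 64 tokens each
neighbors = [retrieve_neighbors(c) for c in chunks]

print(len(chunks))        # 4
print(len(neighbors[0]))  # 2 neighbors per chunk
```

In the real pipeline, the lookup is a nearest-neighbor search over the pre-built index in --retro-work-dir, and the neighbors are token sequences, not strings.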
retro_model.py
The logic resides in megatron/core/models/retro. The RETRO model is unique because it has a small encoder inside a large decoder.
The RetroModel initializes a standard GPT structure but adds an encoder to process the retrieved text.
# megatron/core/models/retro/model.py
class RetroModel(MegatronModule):
    def __init__(self, config, ...):
        super().__init__()

        # 1. The Retrieval Encoder
        # This reads the external "notes" found in the database
        self.retrieval_encoder = RetroEncoder(config, ...)

        # 2. The Main Decoder (GPT)
        # This writes the text, occasionally looking at the encoder
        self.decoder = RetroDecoder(config, ...)
This is the most critical part. We don't just pass input_ids. We also pass retrieved_ids.
def forward(self, input_ids, retrieved_ids, ...):
    # Step 1: Read the external notes
    # The encoder turns retrieved text into math (embeddings)
    # These are the "keys" and "values" for attention
    encoder_output = self.retrieval_encoder(retrieved_ids)

    # Step 2: Generate text
    # The decoder uses the user input AND the encoder_output
    decoder_output = self.decoder(
        input_ids,
        encoder_output=encoder_output  # <--- Looking at the notes
    )
    return decoder_output
In a standard GPT, layers only look at the past (Self-Attention). In RETRO, specific layers (e.g., every 3rd layer) also look at the retrieved data.
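Which layers get this extra attention is a configuration choice. As a sketch, assume the pattern "every 3rd layer, starting at layer 6" (one common setup from the RETRO paper; the exact start and stride are hypothetical here, not fixed rules):

```python
# Which decoder layers perform retrieval cross-attention?
# Sketch assuming the pattern "every 3rd layer, starting at layer 6"
# (1-based layer numbers). The real pattern is set by configuration.

def retrieval_layers(num_layers, start=6, stride=3):
    return [layer for layer in range(1, num_layers + 1)
            if layer >= start and (layer - start) % stride == 0]

print(retrieval_layers(12))  # [6, 9, 12]
```

So in a 12-layer model under this assumption, only layers 6, 9, and 12 would "glance at the notes"; every other layer behaves like a normal GPT layer.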
class RetroDecoderLayer(MegatronModule):
    def forward(self, hidden_states, encoder_output, ...):
        # 1. Standard Self-Attention (Look at what I just wrote)
        hidden_states = self.self_attention(hidden_states)

        # 2. Cross-Attention (Look at the external notes)
        # Only happens if this layer is configured for retrieval
        if self.use_retro_attention:
            hidden_states = self.cross_attention(
                hidden_states,
                encoder_output  # <--- The bridge to external data
            )

        # 3. Feed Forward (Process information)
        return self.mlp(hidden_states)
Beginner Explanation: Think of self.cross_attention as the moment the student looks up from their exam paper to glance at the textbook. They take that information and combine it with what they have already written.
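For intuition, here is a minimal single-head cross-attention sketch in plain NumPy. This is not Megatron's implementation: real layers add learned query/key/value projection matrices and multiple heads, which are omitted here so the core idea stays visible.

```python
import numpy as np

# Minimal single-head cross-attention sketch (plain NumPy, no Megatron code).
# Queries come from the decoder's hidden states ("what I just wrote");
# keys and values come from the encoder output ("the textbook").
# Learned projections and multi-head logic are deliberately omitted.

def cross_attention(hidden_states, encoder_output):
    d = hidden_states.shape[-1]
    q = hidden_states                   # (seq_len, d)      queries
    k = encoder_output                  # (n_neighbors, d)  keys
    v = encoder_output                  # (n_neighbors, d)  values
    scores = q @ k.T / np.sqrt(d)       # how relevant is each note to each token
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the notes
    return weights @ v                  # weighted blend of the retrieved notes

hidden = np.random.randn(4, 8)   # 4 decoder tokens, hidden size 8
notes = np.random.randn(6, 8)    # 6 retrieved-text vectors
out = cross_attention(hidden, notes)
print(out.shape)  # (4, 8): one blended "note summary" per decoder token
```

Each output row is a mixture of the retrieved vectors, weighted by how relevant they are to that token: the student glancing at the textbook and taking away exactly what the current sentence needs.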
In this chapter, you learned:

- RETRO gives a GPT-style decoder an "open book": an external database of pre-processed text chunks.
- --retro-work-dir points the training script at that database, and --retro-add-retriever switches retrieval on.
- The input sequence is split into chunks, and each chunk retrieves relevant "Neighbors" from the database.
- A small RetroEncoder embeds the retrieved text, and Chunked Cross-Attention in specific decoder layers merges it with the model's own hidden states.
We have now covered almost every way to process text: Writers, Readers, Translators, Memory-Savers, Experts, and Researchers (RETRO).
But the world isn't just text. What about Images? How do we teach a Transformer to see?
Next Chapter: CLIP / SigLIP / InternViT (Vision)