Chapter 6 · CORE

RETRO

๐Ÿ“„ 06_retro.md ๐Ÿท Core

Chapter 6: RETRO

In the previous chapter, Mixture of Experts (MoE), we learned how to give a model more capacity without a matching increase in compute per token by creating "specialists" inside the brain.

But even the smartest expert has a limit: memory. If you ask a standard GPT model "Who won the game yesterday?", it cannot answer, because everything it knows was frozen into its weights when training ended, months or years ago.

To fix this, we don't need a bigger brain; we need an Open Book. We need RETRO (Retrieval-Enhanced Transformer).

Motivation: The Open-Book Exam

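Imagine you are taking a very hard history exam. In a closed-book exam, you must memorize every date and name in advance. In an open-book exam, you can look facts up in the textbook and spend your effort on understanding the question. A standard GPT takes every exam closed-book; RETRO gets to bring the book.

Use Case: Factual question answering, or updating a model's knowledge without retraining it.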

What is RETRO?

RETRO is a modification of the standard GPT (decoder-only) architecture that adds an external memory system with three parts:

  1. The Database: A massive collection of text chunks (e.g., Wikipedia, News, Code).
  2. The Retriever: A search engine that finds chunks relevant to your current input.
  3. The Integration: The model uses Cross-Attention, the same mechanism that encoder-decoder models such as T5 use, to "read" these chunks while generating text.

graph LR
    Input[User Input] --> Search[Retriever]
    DB[(Database)] --> Search
    Search --> Notes[Retrieved Notes]
    Notes --> Model[RETRO Model]
    Input --> Model
    Model --> Output[Answer]
    style DB fill:#eee,stroke:#333
    style Model fill:#f9f,stroke:#333
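The Retriever component can be sketched in a few lines. This is a toy, not RETRO's actual method: the paper embeds chunks with a frozen BERT model and searches them with approximate nearest neighbors, while this sketch swaps in bag-of-words cosine similarity (with a small stopword list) so it runs on the standard library alone. `DATABASE`, `STOPWORDS`, and `retrieve` are hypothetical names used only for illustration.

```python
import math
from collections import Counter

# Hypothetical toy database of text chunks (the "open book").
DATABASE = [
    "Paris is the capital and most populous city of France.",
    "Berlin is the capital of Germany.",
    "The Eiffel Tower is a landmark in Paris.",
    "Mount Everest is the highest mountain on Earth.",
]

# Very common words that carry no topic information.
STOPWORDS = {"a", "and", "in", "is", "of", "on", "the", "to"}

def bag_of_words(text):
    """Lower-cased word counts, with simple punctuation and stopwords removed."""
    cleaned = text.lower().replace(".", " ").replace(",", " ")
    return Counter(w for w in cleaned.split() if w not in STOPWORDS)

def cosine(a, b):
    """Cosine similarity between two word-count vectors."""
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def retrieve(query, database, k=2):
    """Return the k chunks most similar to the query (its "neighbors")."""
    q = bag_of_words(query)
    ranked = sorted(database, key=lambda chunk: cosine(q, bag_of_words(chunk)), reverse=True)
    return ranked[:k]

print(retrieve("The capital of France is", DATABASE))
# ['Paris is the capital and most populous city of France.', 'Berlin is the capital of Germany.']
```

The Berlin chunk still comes back as the second neighbor because it shares the word "capital"; a learned embedding retriever is what lets the real system rank by meaning rather than word overlap.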

Using Megatron-LM: pretrain_retro.py

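Training a RETRO model requires two things beyond a GPT setup: a pre-built retrieval database (text chunks plus a nearest-neighbor search index over them) and a dedicated training script, pretrain_retro.py.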
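To make "database of text chunks" concrete, here is a minimal sketch of how one tokenized document could be cut into database entries. It follows the layout described in the RETRO paper, where each stored chunk N is kept together with its continuation F (the next chunk in the same document), so a retrieved neighbor is [N, F]. The function name `build_chunk_database` is hypothetical, the nearest-neighbor index over the chunks is not shown, and keeping a short final chunk is an assumption of this sketch; a real pipeline may pad or drop it.

```python
def build_chunk_database(token_ids, chunk_length):
    """Split one tokenized document into fixed-size database entries.

    Each entry pairs a chunk N with its continuation F (the chunk that
    follows it in the same document). The final chunk has an empty
    continuation because nothing follows it.
    """
    chunks = [token_ids[i:i + chunk_length] for i in range(0, len(token_ids), chunk_length)]
    database = []
    for idx, chunk in enumerate(chunks):
        continuation = chunks[idx + 1] if idx + 1 < len(chunks) else []
        database.append({"chunk": chunk, "continuation": continuation})
    return database

# Toy "document" of 10 token ids, cut into chunks of 4 tokens
# (the RETRO paper uses 64-token chunks).
db = build_chunk_database(list(range(10)), chunk_length=4)
for entry in db:
    print(entry)
# {'chunk': [0, 1, 2, 3], 'continuation': [4, 5, 6, 7]}
# {'chunk': [4, 5, 6, 7], 'continuation': [8, 9]}
# {'chunk': [8, 9], 'continuation': []}
```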

A Simple Training Command

The command looks similar to a GPT pretraining command, but you must also point it to your retrieval data.

python pretrain_retro.py \
    --num-layers 12 \
    --hidden-size 768 \
    --retro-work-dir /path/to/retrieval_data \
    --retro-add-retriever \
    --seq-length 1024 \
    --data-path my_text_data

Key Arguments:

  1. --retro-work-dir: The folder containing your "external brain" (the pre-processed database of text chunks).
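  2. --retro-add-retriever: Adds the retrieval components (the neighbor encoder and cross-attention layers) to the model. In effect, it tells the model, "Don't just guess! Look up information first."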

Understanding the Input and Output

  1. Input: "The capital of France is..."
  2. Retrieval: The system searches the database and finds: "Paris is the capital and most populous city of France."
  3. Processing: The decoder reads the input while its cross-attention layers read the retrieved text.
  4. Output: "Paris."

Under the Hood: The Internal Flow

How does the model merge its own thoughts with the external notes? It uses a Chunked Cross-Attention mechanism.

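The input sequence is split into fixed-size chunks (64 tokens in the original RETRO paper). For each chunk, the system retrieves "neighbors" (relevant passages from the database). To keep generation causal, the neighbors retrieved with a chunk are read only by the last token of that chunk and the tokens that follow it, so no token ever sees information that was retrieved using tokens after it.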
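The causal rule above can be written as one line of index arithmetic. This sketch follows the shift described in the RETRO paper (neighbors of chunk u are attended by the last token of chunk u and the first chunk_length - 1 tokens of chunk u + 1); the function name `neighbor_chunk_for_token` is hypothetical, and positions and chunk indices are 0-based here.

```python
def neighbor_chunk_for_token(position, chunk_length):
    """Index of the chunk whose retrieved neighbors this token may attend to.

    Neighbors retrieved with chunk u are read by the last token of chunk u
    and the first chunk_length - 1 tokens of chunk u + 1. Tokens before the
    end of the first chunk come before any retrieval, so they get None.
    """
    u = (position + 1) // chunk_length - 1
    return u if u >= 0 else None

# Toy sizes: a 10-token sequence with 4-token chunks.
seq_length, chunk_length = 10, 4
print([neighbor_chunk_for_token(p, chunk_length) for p in range(seq_length)])
# [None, None, None, 0, 0, 0, 0, 1, 1, 1]
```

One consequence of this design is that the neighbors of the final chunk (chunk 2 here) are never read during training, since no tokens follow it.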

sequenceDiagram
    participant U as User Input
    participant R as Retriever
    participant E as Encoder (Read Notes)
    participant D as Decoder (Write Text)
    U->>R: "The capital of France..."
    R->>R: Search Database
    R->>E: Send Found Text ("Paris is...")
    Note over E: Encodes "Notes" into vectors
    E->>D: Send Key/Value Pairs
    U->>D: Send "The capital of France..."
    Note over D: Look at Input AND Notes
    D->>D: Generate "Paris"

Diving into the Code: retro_model.py

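The logic resides in megatron/core/models/retro. The RETRO model is unusual because it has a small encoder inside a large decoder. The snippets below are simplified to show the idea; the actual Megatron-LM classes and signatures differ.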

1. The Wrapper

The RetroModel initializes a standard GPT structure but adds an encoder to process the retrieved text.

# megatron/core/models/retro/model.py

class RetroModel(MegatronModule):
    def __init__(self, config, ...):
        super().__init__()
        
        # 1. The Retrieval Encoder
        # This reads the external "notes" found in the database
        self.retrieval_encoder = RetroEncoder(config, ...)

        # 2. The Main Decoder (GPT)
        # This writes the text, occasionally looking at the encoder
        self.decoder = RetroDecoder(config, ...)

2. The Forward Pass

This is the most critical part. We don't just pass input_ids. We also pass retrieved_ids.

    def forward(self, input_ids, retrieved_ids, ...):
        # Step 1: Read the external notes
        # The encoder turns the retrieved tokens into vectors (embeddings)
        # These are the "keys" and "values" for attention
        encoder_output = self.retrieval_encoder(retrieved_ids)

        # Step 2: Generate text
        # The decoder uses the user input AND the encoder_output
        decoder_output = self.decoder(
            input_ids,
            encoder_output=encoder_output # <--- Looking at the notes
        )
        
        return decoder_output
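It helps to know how big retrieved_ids is compared with input_ids. The sketch below assumes the layout from the RETRO paper: the sequence is cut into 64-token chunks, each chunk gets k neighbors, and each neighbor is a chunk plus its continuation (2 × 64 = 128 tokens). The function name `retrieved_ids_shape` and the default of two neighbors are assumptions for illustration; Megatron-LM may arrange the tensor differently internally.

```python
def retrieved_ids_shape(batch_size, seq_length, chunk_length=64, num_neighbors=2):
    """Conceptual shape of the retrieved-token tensor for one batch.

    (batch, chunks per sequence, neighbors per chunk, tokens per neighbor),
    where each neighbor is a chunk plus its continuation.
    """
    if seq_length % chunk_length != 0:
        raise ValueError("seq_length must be a multiple of chunk_length")
    num_chunks = seq_length // chunk_length
    neighbor_length = 2 * chunk_length
    return (batch_size, num_chunks, num_neighbors, neighbor_length)

print(retrieved_ids_shape(batch_size=4, seq_length=1024))  # (4, 16, 2, 128)
```

With these numbers, each 1024-token sequence brings along 16 × 2 × 128 = 4096 retrieved tokens, which is why the encoder that reads them is kept small.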

3. Inside the Decoder Layer

In a standard GPT, every layer only looks at past tokens (causal Self-Attention). In RETRO, specific layers also look at the retrieved data; the original paper adds chunked cross-attention to every third layer, starting from layer 6.
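The layer schedule from the paper is easy to compute. This is a sketch of that schedule only; the function name `retro_layer_numbers` is hypothetical, and Megatron-LM's own default may start at a different layer, especially for deeper models.

```python
def retro_layer_numbers(num_layers, first=6, every=3):
    """1-based indices of the decoder layers that get chunked cross-attention.

    Uses the RETRO paper's schedule: every third layer, starting at layer 6.
    """
    return list(range(first, num_layers + 1, every))

print(retro_layer_numbers(12))  # [6, 9, 12]
print(retro_layer_numbers(24))  # [6, 9, 12, 15, 18, 21, 24]
```

So in the 12-layer model from the training command, only 3 of the 12 layers would pay the extra cost of reading the retrieved notes.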

class RetroDecoderLayer(MegatronModule):
    def forward(self, hidden_states, encoder_output, ...):
        # Simplified: layer norms are omitted, and each sub-layer's output
        # is added back to its input through a residual connection.

        # 1. Causal Self-Attention (Look at what I have written so far)
        hidden_states = hidden_states + self.self_attention(hidden_states)

        # 2. Chunked Cross-Attention (Look at the external notes)
        # Only happens if this layer is configured for retrieval
        if self.use_retro_attention:
            hidden_states = hidden_states + self.cross_attention(
                hidden_states,
                encoder_output,  # <--- The bridge to external data
            )

        # 3. Feed Forward (Process information)
        return hidden_states + self.mlp(hidden_states)

Beginner Explanation: Think of self.cross_attention as the moment the student looks up from their exam paper to glance at the textbook. They take that information and combine it with what they have already written.
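Numerically, "glancing at the textbook" is scaled dot-product attention where the query comes from the decoder and the keys and values come from the encoded neighbors. The sketch below shows that core for a single decoder token and a single head in plain Python. It deliberately leaves out the learned query/key/value projections, multiple heads, and the chunking rule, so treat it as an illustration of the mechanism rather than Megatron-LM's implementation.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(query, keys, values):
    """Single-head scaled dot-product cross-attention for one decoder token.

    query:  vector from the decoder (what this token is "looking for").
    keys:   one vector per retrieved-neighbor token (from the encoder).
    values: one vector per retrieved-neighbor token (from the encoder).
    Returns the attention-weighted mix of the values.
    """
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

# Toy example: the query matches the first key far better than the second,
# so the output is almost entirely the first value.
query = [4.0, 0.0]
keys = [[4.0, 0.0], [0.0, 4.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
print(cross_attention(query, keys, values))
```

Because the weights sum to 1, the output is always a blend of the notes; the query only decides how much of each note to take.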

Summary

In this chapter, you learned:

  1. RETRO allows models to look up information from an external database instead of memorizing everything.
  2. It uses a Retriever to find relevant text chunks ("Neighbors").
  3. It uses Cross-Attention to incorporate that information into the generation process.
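  4. This architecture is well suited to knowledge-intensive tasks: because facts live in the database rather than only in the weights, you can update or extend them by rebuilding the database instead of retraining the model.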

We have now covered almost every way to process text: Writers, Readers, Translators, Memory-Savers, Experts, and Researchers (RETRO).

But the world isn't just text. What about Images? How do we teach a Transformer to see?

Next Chapter: CLIP / SigLIP / InternViT (Vision)


Generated by Code IQ