Chapter 4: Mamba (State Space Model)

In the previous chapters, we explored the "Transformer" family:

  1. GPT (Decoder-only) for writing.
  2. BERT (Encoder-only) for understanding.
  3. T5 (Encoder-Decoder) for translation.

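All three share the same backbone: Attention. Attention is powerful, but it has a major weakness: every token is compared with every other token, so compute and memory grow roughly with the square of the sequence length. As the text gets longer, Attention gets disproportionately slower and more memory-hungry.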

Enter Mamba.

Motivation: The "Infinite" Book Reader

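Imagine you are reading a 1,000-page novel. A Transformer handles each new word by looking back over everything it has read so far, which gets more expensive with every page. Mamba instead keeps a running summary in its head and updates that summary as each new word arrives.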

Use Case: Processing extremely long sequences (like whole DNA strands or entire code repositories).

What is a State Space Model (SSM)?

Mamba is not a Transformer; it is a State Space Model.

To understand it, think of a Conveyor Belt.

  1. Input: Raw materials (Tokens) come in.
  2. The State (Hidden State): This is the machine on the belt. It holds the "current status" of the product.
  3. Update Rule: When a new token arrives, the machine updates its internal status based on the current status + new input.
  4. Output: The machine produces a prediction.
graph LR
    subgraph "Transformer (Attention)"
        A[Word 1] --- B[Word 2]
        B --- C[Word 3]
        A --- C
        style A fill:#f9f,stroke:#333
        style B fill:#f9f,stroke:#333
        style C fill:#f9f,stroke:#333
    end
    subgraph "Mamba (SSM)"
        S1((State)) --> S2((State))
        S2 --> S3((State))
        I1[Word 1] --> S1
        I2[Word 2] --> S2
        I3[Word 3] --> S3
        style S1 fill:#bbf,stroke:#333
        style S2 fill:#bbf,stroke:#333
        style S3 fill:#bbf,stroke:#333
    end
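The conveyor-belt update rule can be sketched in a few lines of plain Python. This is only an illustration with a single number as the state; the function name `run_state_machine` and the weights `decay`, `input_weight`, and `output_weight` are made up for this example and do not come from Megatron-LM.

```python
def run_state_machine(tokens, decay=0.5, input_weight=1.0, output_weight=2.0):
    """Toy scalar recurrence: state = decay * state + input_weight * x."""
    state = 0.0                 # the "machine on the belt" starts empty
    outputs = []
    for x in tokens:            # one pass over the sequence, one token at a time
        # Update rule: new status depends on the current status + new input
        state = decay * state + input_weight * x
        # Output: the prediction is read from the state, not from past tokens
        outputs.append(output_weight * state)
    return outputs, state


outputs, final_state = run_state_machine([1.0, 2.0, 3.0])
print(outputs)       # [2.0, 5.0, 8.5]
print(final_state)   # 4.25
```

Notice that the loop never looks back at earlier tokens. Everything the model knows about the past has to fit inside `state`.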

Key Benefit: Mamba has Linear Complexity. If you double the text length, the work roughly doubles, whereas the Attention computation in a Transformer grows roughly 4x.
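To make the scaling concrete, here is a small counting sketch. It is not a benchmark: it only counts token-to-token comparisons for Attention and state updates for an SSM, and both helper names are hypothetical.

```python
def attention_pair_count(seq_len):
    # Full attention scores every token against every token.
    # (Causal masking skips about half of these, but growth is still quadratic.)
    return seq_len * seq_len


def ssm_step_count(seq_len):
    # A recurrent state model performs one state update per token.
    return seq_len


for n in (1024, 2048, 4096):
    print(n, attention_pair_count(n), ssm_step_count(n))
```

Doubling `n` from 1024 to 2048 multiplies the attention count by 4 and the SSM count by 2.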

Using Megatron-LM: pretrain_mamba.py

Megatron-LM includes a specialized script for this architecture.

A Simple Training Command

Training Mamba looks similar to GPT, but under the hood, the math is completely different.

python pretrain_mamba.py \
    --num-layers 24 \
    --hidden-size 1024 \
    --ssm-state-size 16 \
    --seq-length 4096 \
    --data-path my_long_text_data

New Arguments:

  1. --ssm-state-size 16: This is the size of the "mental summary" or "memory" the model keeps for each feature. A larger state captures more details but is slower to compute.
  2. --seq-length: Because its cost grows linearly with length, Mamba can work with much longer sequences than a standard Transformer at a comparable budget.
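A rough way to see why the state size matters is to compare what each architecture has to remember per layer while generating text. The sketch below is a back-of-the-envelope estimate under stated assumptions: a Transformer layer caches one key and one value vector per past token, and a Mamba layer keeps one state vector of length `state_size` per inner channel, with an assumed expansion factor of 2. It ignores the small convolution buffer and is not taken from Megatron-LM.

```python
def kv_cache_elements(seq_len, hidden_size):
    # Keys + values for every token seen so far, in one attention layer.
    return 2 * seq_len * hidden_size


def ssm_state_elements(hidden_size, state_size, expand=2):
    # Assumption: inner width = expand * hidden_size, and each inner channel
    # carries its own state vector of length state_size.
    return expand * hidden_size * state_size


hidden_size, state_size = 1024, 16   # values from the command above
for seq_len in (4096, 8192):
    print(seq_len,
          kv_cache_elements(seq_len, hidden_size),
          ssm_state_elements(hidden_size, state_size))
```

Under these assumptions the attention cache doubles when the sequence doubles, while the SSM state stays at 32,768 numbers per layer no matter how long the text is.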

Under the Hood: The Internal Flow

How does Mamba process data? It uses a mechanism called Selective Scan. It decides what information to keep in its "State" and what to forget (like a filter).

sequenceDiagram
    participant I as Input (Token)
    participant M as Mixer (SSM Layer)
    participant S as Hidden State
    participant O as Output

    Note over I, O: Step 1: Processing "The"
    I->>M: Send "The"
    S->>M: Send Old State (Empty)
    M->>S: Update State (Remember "The")
    M->>O: Predict Next

    Note over I, O: Step 2: Processing "Quick"
    I->>M: Send "Quick"
    S->>M: Send State (Has "The")
    M->>S: Update State (Remember "The Quick")
    M->>O: Predict Next

Diving into the Code: mamba_model.py

The implementation lives in megatron/core/models/mamba/mamba_model.py.

1. The Wrapper

The MambaModel replaces the Transformer Decoder with a stack of Mamba Layers (often called "Mixer Layers").

# megatron/core/models/mamba/mamba_model.py

class MambaModel(MegatronModule):
    def __init__(self, config, ...):
        super().__init__()
        
        # 1. Embeddings: Same as GPT (Words -> Vectors)
        self.embedding = LanguageModelEmbedding(config, ...)

        # 2. The Backbone: A stack of Mamba Layers
        # Instead of 'TransformerBlock', we use 'MambaStack'
        self.decoder = MambaStack(config, ...)

2. The Mamba Layer (The Mixer)

Inside MambaStack, we don't have Attention Heads. We have the SSM Mixer. This is where the "State" magic happens.

class MambaLayer(MegatronModule):
    def forward(self, hidden_states, ...):
        # 1. Project inputs to higher dimensions
        # This prepares the data for the state machine
        xz = self.in_proj(hidden_states)

        # 2. Run the SSM (Selective Scan)
        # This updates the 'State' and calculates output linearly
        out = mamba_inner_fn(xz, self.conv1d_weight, ...)

        # 3. Project back to normal size
        return self.out_proj(out)

Beginner Note: mamba_inner_fn is a highly optimized function, backed by fused CUDA kernels, that performs the "running summary" math for the whole sequence in a single call.

3. Why "Selective"?

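Earlier State Space Models apply the same fixed update to every token, regardless of what the token is. Mamba is Selective.

In the code, the model computes specific parameters (often called B, C, and Delta) from the current input, so they act like gates: they control how much of each token is written into the State and how much of the old State is kept.

This allows Mamba to handle content-dependent tasks (like "Copy the first word of the sentence to the end") that older SSMs struggled with.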
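The gating idea can be sketched as a slow, sequential reference loop for a single channel. This is a minimal sketch, not the fused kernel: the weights `w_B`, `w_C`, `w_delta`, and `b_delta` are hypothetical stand-ins for learned projections, and the `fixed_delta` switch exists only to mimic an older, non-selective SSM for comparison.

```python
import math


def softplus(v):
    # Keeps the step size Delta positive.
    return math.log1p(math.exp(v))


def selective_scan(xs, A, w_B, w_C, w_delta, b_delta, fixed_delta=None):
    """Toy single-channel scan; B, C and Delta are computed from each input x."""
    n = len(A)
    h = [0.0] * n                      # the hidden State
    ys = []
    for x in xs:
        if fixed_delta is None:
            delta = softplus(w_delta * x + b_delta)   # input-dependent gate
        else:
            delta = fixed_delta                       # same step for every token
        B = [w * x for w in w_B]       # how strongly x is written into the State
        C = [w * x for w in w_C]       # how the State is read out
        # Keep exp(delta * A) of the old State, add delta * B * x of the input.
        h = [math.exp(delta * A[i]) * h[i] + delta * B[i] * x for i in range(n)]
        ys.append(sum(C[i] * h[i] for i in range(n)))
    return ys, h


xs = [1.0, 0.0, 0.0, 0.0]              # one real token, then three fillers
A, w_B, w_C = [-1.0, -2.0], [1.0, 1.0], [1.0, 1.0]

_, h_selective = selective_scan(xs, A, w_B, w_C, w_delta=20.0, b_delta=-20.0)
_, h_fixed = selective_scan(xs, A, w_B, w_C, 0.0, 0.0, fixed_delta=math.log(2))
print(h_selective[0])   # about 0.693: the first token is still remembered
print(h_fixed[0])       # about 0.087: the memory has decayed
```

For the filler tokens the selective version computes a Delta close to zero, so the State is left almost untouched. With a fixed Delta the State decays at every step whether or not the token matters.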

Summary

In this chapter, you learned:

  1. Mamba is a State Space Model (SSM), not a Transformer.
  2. It avoids the quadratic compute and memory growth of Attention, which makes extremely long sequences practical.
  3. It works by maintaining a Recurrent State (a running summary) rather than looking back at history.
  4. You use pretrain_mamba.py to train it.

We have now covered the major architectural "shapes" (Encoder, Decoder, Encoder-Decoder, and SSM).

But what if you want to make your model massive, with far more knowledge packed into its parameters, without paying the cost of computing every neuron for every word? You need a way to use only the parts of the brain relevant to the current topic.

Next Chapter: Mixture of Experts (MoE)

