
Implementing Mamba Model from Scratch in PyTorch

Building the Mamba model from scratch in PyTorch, based on the research paper https://arxiv.org/pdf/2312.00752.

šŸ Implementing the Mamba Model from Scratch

This project focuses on implementing the Mamba model based on the research paper Mamba: Linear-Time Sequence Modeling with Selective State Spaces. The Mamba architecture presents a linear-time alternative to Transformers using Selective State Space Models (SSMs) for efficient long-sequence modeling.


🔬 How Mamba Works

šŸ—ļø Key Components of the Mamba Model:

  1. State Space Models (SSMs): Instead of self-attention, Mamba processes sequences with a recurrent state-space representation, so the cost of each step stays constant as the sequence grows.
  2. Gated Selective Update Mechanism: Input-dependent gates decide how strongly each token updates the hidden state, letting the model keep or discard information selectively.
  3. Structured Parameterization: Reduces complexity while maintaining expressive power.

Unlike Transformers, Mamba models scale linearly O(N) with sequence length rather than quadratically O(N²), making them ideal for long-context applications.
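
For intuition, the heart of an SSM layer is a simple recurrent scan. The sketch below is a heavily simplified, hypothetical illustration using diagonal per-channel parameters A, B, C and a step size dt (the names follow the paper's notation, but this helper is not part of the repository); in the actual Mamba architecture, B, C, and the step size are themselves computed from the input (the "selective" part) and the scan runs as a hardware-aware parallel kernel rather than a Python loop.

import torch

def ssm_scan(x, A, B, C, dt):
    """Toy diagonal SSM: h_t = exp(dt * A) * h_{t-1} + dt * B * x_t, y_t = C * h_t."""
    # x: (batch, seq_len, channels); A, B, C, dt: per-channel tensors of shape (channels,)
    batch_size, seq_len, d = x.shape
    h = torch.zeros(batch_size, d)            # hidden state, one value per channel
    decay = torch.exp(dt * A)                 # discretized state transition
    outputs = []
    for t in range(seq_len):
        h = decay * h + dt * B * x[:, t, :]   # constant amount of work per step
        outputs.append(C * h)                 # linear readout
    return torch.stack(outputs, dim=1)        # total cost grows linearly with seq_len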


šŸ› ļø Implementing Mamba from Scratch

Below is a from-scratch PyTorch implementation of a simplified Mamba-style model.

import torch
import torch.nn as nn
 
def selective_state_update(hidden_state, input_vector, gate, dt):
    """Selective update mechanism for the state-space representation.

    The gate interpolates between keeping the previous hidden state (gate near 0)
    and writing the new input scaled by the step size dt (gate near 1).
    """
    return hidden_state * (1 - gate) + input_vector * dt * gate
 
class MambaCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, dt=0.01):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W_in = nn.Linear(input_dim, hidden_dim)
        self.W_gate = nn.Linear(input_dim, hidden_dim)
        self.W_out = nn.Linear(hidden_dim, input_dim)
        self.dt = dt
 
    def forward(self, x, hidden_state):
        """Forward pass for a single Mamba cell at one time step."""
        input_vector = torch.tanh(self.W_in(x))   # candidate state contribution
        gate = torch.sigmoid(self.W_gate(x))      # input-dependent (selective) gate in [0, 1]
        new_hidden = selective_state_update(hidden_state, input_vector, gate, self.dt)
        output = self.W_out(new_hidden)           # project back to the input dimension
        return output, new_hidden
 
class MambaModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([MambaCell(input_dim, hidden_dim) for _ in range(num_layers)])
        self.hidden_dim = hidden_dim
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # One hidden state per layer, initialized to zeros on the input's device.
        hidden_states = [torch.zeros(batch_size, self.hidden_dim, device=x.device) for _ in self.layers]
        outputs = []
        for t in range(seq_len):
            input_t = x[:, t, :]
            # Each layer maps back to input_dim, so its output feeds the next layer.
            for i, layer in enumerate(self.layers):
                input_t, hidden_states[i] = layer(input_t, hidden_states[i])
            outputs.append(input_t.unsqueeze(1))
        return torch.cat(outputs, dim=1)  # (batch_size, seq_len, input_dim)
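
As a quick sanity check (an illustrative snippet, not part of the training script below), the model can be instantiated and run on a random batch to confirm that the output keeps the input's batch and sequence dimensions:

# Shape check for the model defined above (hypothetical example values).
demo_model = MambaModel(input_dim=10, hidden_dim=64, num_layers=3)
dummy = torch.randn(4, 50, 10)   # (batch_size, seq_len, input_dim)
out = demo_model(dummy)
print(out.shape)                 # expected: torch.Size([4, 50, 10])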

📊 Training the Mamba Model

We train the Mamba model using a synthetic sequential dataset.

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
 
# Generate synthetic sequence data
def generate_synthetic_data(num_samples=1000, seq_len=50, input_dim=10):
    X = np.random.randn(num_samples, seq_len, input_dim).astype(np.float32)
    y = np.sum(X, axis=1)  # Example task: predict the sum over the sequence, shape (num_samples, input_dim)
    return torch.tensor(X), torch.tensor(y)
 
# Load data
X_train, y_train = generate_synthetic_data()
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
 
# Initialize model
model = MambaModel(input_dim=10, hidden_dim=64, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
 
# Train the model
for epoch in range(10):
    total_loss = 0
    for X_batch, y_batch in dataloader:
        optimizer.zero_grad()
        predictions = model(X_batch).sum(dim=1)  # Summing along sequence length
        loss = criterion(predictions, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
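
To get a rough sense of generalization on this synthetic task (an illustrative addition, reusing the helpers defined above), the trained model can be evaluated on freshly generated data:

# Hypothetical evaluation on held-out synthetic data.
model.eval()
with torch.no_grad():
    X_test, y_test = generate_synthetic_data(num_samples=200)
    test_predictions = model(X_test).sum(dim=1)
    test_loss = criterion(test_predictions, y_test)
print(f"Test MSE: {test_loss.item():.4f}")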

šŸ” Observations & Future Work

šŸ† Advantages of Mamba over Transformers:

  • Linear Time Complexity (vs. O(N²) for Transformers).
  • Handles Long Sequences Efficiently without quadratic memory requirements.
  • Better Interpretability: hidden states are updated by an explicit, inspectable rule rather than implicit attention weights.

🚀 Future Improvements:

  • Implement more complex gated update mechanisms.
  • Train Mamba on real-world NLP tasks (e.g., text generation, classification).
  • Compare performance with Transformers and LSTMs.

📄 Get Started

Clone the repository and run the model:

git clone https://github.com/YourUsername/Mamba-Model.git
cd Mamba-Model
pip install -r requirements.txt
python train_mamba.py

💡 Contribute: Optimize the state-space mechanisms or test on real-world datasets!

🔗 GitHub Repository