# Implementing the Mamba Model from Scratch
This project implements the Mamba model described in the research paper *Mamba: Linear-Time Sequence Modeling with Selective State Spaces*. The Mamba architecture is a linear-time alternative to Transformers that uses selective state space models (SSMs) for efficient long-sequence modeling.
## How Mamba Works
### Key Components of the Mamba Model
- State Space Models (SSMs): Instead of self-attention, Mamba processes sequences using state-space representations to maintain efficiency.
- Gated Selective Update Mechanism: Allows selective updates to hidden states, reducing unnecessary computations.
- Structured Parameterization: Reduces complexity while maintaining expressive power.
Unlike Transformers, whose self-attention cost grows quadratically (O(N²)) with sequence length, Mamba scales linearly (O(N)), which makes it well suited to long-context applications.
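To make the recurrence concrete, here is a minimal sequential sketch of the selective SSM update from the paper, assuming a diagonal state matrix `A` and input-dependent `B`, `C`, and step size `delta`. The function and variable names are illustrative rather than taken from the paper's reference code, and the real model evaluates this scan with a hardware-aware parallel kernel rather than a Python loop.

```python
import torch

def selective_scan(x, A, B, C, delta):
    """Sequential reference for the selective SSM recurrence (illustrative sketch).

    x:     (batch, seq_len, d)   input sequence
    A:     (d, n)                diagonal state matrix (typically negative for stability)
    B, C:  (batch, seq_len, n)   input-dependent projections
    delta: (batch, seq_len, d)   input-dependent step sizes (typically positive via softplus)
    """
    batch, seq_len, d = x.shape
    h = torch.zeros(batch, d, A.shape[1], device=x.device)
    ys = []
    for t in range(seq_len):
        # Discretize: A_bar = exp(delta * A), B_bar ~ delta * B
        dA = torch.exp(delta[:, t].unsqueeze(-1) * A)          # (batch, d, n)
        dB = delta[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)  # (batch, d, n)
        h = dA * h + dB * x[:, t].unsqueeze(-1)                # selective state update
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))          # read out (batch, d)
    return torch.stack(ys, dim=1)                              # (batch, seq_len, d)
```

Because `delta`, `B`, and `C` depend on the current input, the update is "selective": some tokens are written strongly into the state while others are largely ignored.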
## Implementing Mamba from Scratch
Below is a compact PyTorch implementation of the model used in this project. It distills the selective update into a gated recurrent cell: each input is projected into the state space, and a learned gate decides how much of it is written into the hidden state.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def selective_state_update(hidden_state, input_vector, gate, dt):
    """Selective update mechanism for the state-space representation.

    The gate decides how much of the previous state is kept versus how much
    of the new (dt-scaled) input is written into the state.
    """
    return hidden_state * (1 - gate) + input_vector * dt * gate


class MambaCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, dt=0.01):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W_in = nn.Linear(input_dim, hidden_dim)    # projects the input into the state space
        self.W_gate = nn.Linear(input_dim, hidden_dim)  # produces the input-dependent gate
        self.W_out = nn.Linear(hidden_dim, input_dim)   # maps the state back to the input dimension
        self.dt = dt

    def forward(self, x, hidden_state):
        """Forward pass for a single Mamba cell at one time step."""
        input_vector = torch.tanh(self.W_in(x))
        gate = torch.sigmoid(self.W_gate(x))
        new_hidden = selective_state_update(hidden_state, input_vector, gate, self.dt)
        output = self.W_out(new_hidden)
        return output, new_hidden


class MambaModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([MambaCell(input_dim, hidden_dim) for _ in range(num_layers)])
        self.hidden_dim = hidden_dim

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # One hidden state per layer, initialized to zeros
        hidden_states = [torch.zeros(batch_size, self.hidden_dim, device=x.device) for _ in self.layers]
        outputs = []
        for t in range(seq_len):
            input_t = x[:, t, :]
            # Each layer consumes the previous layer's output and updates its own state
            for i, layer in enumerate(self.layers):
                input_t, hidden_states[i] = layer(input_t, hidden_states[i])
            outputs.append(input_t.unsqueeze(1))
        return torch.cat(outputs, dim=1)
```
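The listing above is self-contained, so a quick shape check (not part of the original script) confirms that the model maps a `(batch, seq_len, input_dim)` tensor back to the same shape:

```python
# Illustrative shape check: the output keeps the (batch, seq_len, input_dim) layout
model = MambaModel(input_dim=10, hidden_dim=64, num_layers=3)
x = torch.randn(8, 50, 10)
print(model(x).shape)  # torch.Size([8, 50, 10])
```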
## Training the Mamba Model
We train the Mamba model using a synthetic sequential dataset.
```python
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic sequence data
def generate_synthetic_data(num_samples=1000, seq_len=50, input_dim=10):
    X = np.random.randn(num_samples, seq_len, input_dim).astype(np.float32)
    y = np.sum(X, axis=1)  # Example task: predict the sum of the sequence elements
    return torch.tensor(X), torch.tensor(y)

# Load data
X_train, y_train = generate_synthetic_data()
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize model
model = MambaModel(input_dim=10, hidden_dim=64, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Train the model
for epoch in range(10):
    total_loss = 0
    for X_batch, y_batch in dataloader:
        optimizer.zero_grad()
        predictions = model(X_batch).sum(dim=1)  # Sum over the sequence dimension to match the target
        loss = criterion(predictions, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
```
## Observations & Future Work
### Advantages of Mamba over Transformers
- Linear Time Complexity (vs. O(N²) for Transformers).
- Handles Long Sequences Efficiently without quadratic memory requirements.
- Better Interpretability due to explicit state updates.
### Future Improvements
- Implement more complex gated update mechanisms.
- Train Mamba on real-world NLP tasks (e.g., text generation, classification).
- Compare performance with Transformers and LSTMs.
## Get Started
Clone the repository and run the model:
```bash
git clone https://github.com/YourUsername/Mamba-Model.git
cd Mamba-Model
pip install -r requirements.txt
python train_mamba.py
```
Contribute: Optimize the state-space mechanisms or test on real-world datasets!
[GitHub Repository](https://github.com/YourUsername/Mamba-Model)