This is the official PyTorch implementation for the paper "Gated Associative Memory: A Parallel O(N) Architecture for Efficient Sequence Modeling" by Rishiraj Acharya.
This repository contains the source code for the GAM model (`gam_model.py`), the main benchmarking script to reproduce the experimental results (`benchmark.py`), and the scaling analysis script (`scale.py`).
The Transformer architecture, while powerful, scales quadratically with sequence length (O(N²)), creating a bottleneck for long-context tasks. We propose the Gated Associative Memory (GAM) network, a novel, fully parallel architecture for sequence modeling that exhibits linear complexity (O(N)). The GAM block replaces self-attention with two parallel pathways: a causal convolution to capture local context and a parallel associative memory retrieval mechanism for global, content-based patterns. These pathways are dynamically fused via a learned gating mechanism. Our experiments on WikiText-2 and TinyStories show that GAM is consistently faster and achieves superior or competitive perplexity compared to standard Transformer and Mamba baselines, establishing it as a promising and efficient alternative for sequence modeling.
- Linear Complexity: O(N) time and memory complexity, enabling efficient processing of very long sequences.
- Fully Parallelizable: The architecture contains no recurrent components, allowing for maximum utilization of modern hardware like GPUs.
- Dual-Pathway Design: Explicitly decomposes context modeling into a Local Pathway (Causal Convolution) for syntax and a Global Pathway (Associative Memory) for semantics.
- Dynamic Gating: A learned mechanism dynamically fuses local and global information on a per-token basis, adding expressivity.
- Superior Performance: Outperforms strong Transformer and Mamba baselines in both speed and perplexity on benchmark datasets.
The core of our model is the `GAMBlock`, which replaces the traditional multi-head self-attention layer.
Figure 1: The input `x` is normalized and then processed by two parallel pathways. The Local Pathway uses a Causal 1D Convolution, while the Global Pathway queries a learnable Memory Bank. The outputs are dynamically combined by a learned gate and added to the residual connection, followed by a standard FFN.
- Clone the repository:

  ```bash
  git clone https://github.com/rishiraj/gam.git
  cd gam
  ```

- Create a virtual environment (recommended):

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
  ```

- Install the required dependencies. A `requirements.txt` file is provided for easy setup:

  ```bash
  pip install -r requirements.txt
  ```
This repository contains three main Python scripts.
This script runs the main comparative experiments from the paper (Section 4), training and evaluating GAM, Transformer, and Mamba models on WikiText-2 and TinyStories.
Setup: The script uses `wandb` for logging. Please log in to your account and configure the `WANDB_ENTITY` variable in `benchmark.py` with your username or team name.
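For example, after authenticating once with `wandb login`, the entity line in `benchmark.py` would look something like this (the value below is a placeholder, not a real entity):

```python
# In benchmark.py: point W&B logging at your own account or team.
WANDB_ENTITY = "your-username-or-team"  # placeholder value
```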
Run the full benchmark:

```bash
python benchmark.py
```
The script will sequentially:
- Download and prepare the WikiText-2 dataset.
- Train and evaluate GAM, Transformer, and Mamba on WikiText-2.
- Download and prepare the TinyStories dataset.
- Train and evaluate GAM and Transformer on TinyStories.
All results, including training loss, validation perplexity, and epoch times, will be logged to your Weights & Biases project.
This script reproduces the targeted scaling analysis from Section 5.2 of the paper, measuring the forward+backward time and peak memory usage of a single GAM block vs. a Transformer block as sequence length increases.
Run the scaling analysis:

```bash
python scale.py
```
The script will output a table of results directly to the console and save plots (`compute_time.png` and `peak_memory.png`) to the current directory, visually demonstrating GAM's linear scaling compared to the Transformer's quadratic growth.
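If you want a rough idea of how such a measurement can be taken, the sketch below times one forward+backward pass and records peak CUDA memory for a single block. `scale.py` may differ in its details (warmup, repeat count, block construction); the function name `measure` is illustrative.

```python
import time
import torch

def measure(block, seq_len, n_embed, device="cuda", repeats=10):
    """Rough forward+backward timing and peak-memory probe for one block (illustrative)."""
    x = torch.randn(1, seq_len, n_embed, device=device, requires_grad=True)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(repeats):
        out = block(x)
        out.sum().backward()               # backward pass included in the timing
        block.zero_grad(set_to_none=True)
        x.grad = None
    torch.cuda.synchronize(device)
    elapsed_ms = (time.perf_counter() - start) / repeats * 1e3
    peak_mb = torch.cuda.max_memory_allocated(device) / 2**20
    return elapsed_ms, peak_mb

# e.g. measure(my_block.cuda(), seq_len=4096, n_embed=512)  # hypothetical usage
```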
You can easily import and use the `GAM_Model` in your own projects. The model implementation is self-contained in `gam_model.py`.
Here is a minimal example:
```python
import torch
from gam_model import GAM_Model, ModelConfig

# 1. Define the model configuration
config = ModelConfig(
    VOCAB_SIZE=10000,
    BLOCK_SIZE=256,
    N_EMBED=512,
    N_LAYER=6,
    NUM_MEMORY_SLOTS=512,
    CONV_KERNEL_SIZE=3,
)

# 2. Initialize the model
model = GAM_Model(config)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")

# 3. Create dummy data and run a forward pass
dummy_input = torch.randint(0, config.VOCAB_SIZE, (8, 128), device=device)
logits, loss = model(dummy_input, targets=dummy_input)

print(f"Output logits shape: {logits.shape}")
print(f"Calculated loss: {loss.item()}")
```
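Continuing the example above, a simple greedy decoding loop could look like the sketch below. It assumes that calling the model without `targets` returns `(logits, None)` and that the context must be cropped to `BLOCK_SIZE`; check `gam_model.py` for the actual interface, which may differ.

```python
# Hypothetical greedy decoding loop (assumes model(idx) -> (logits, None) without targets).
@torch.no_grad()
def greedy_generate(model, idx, max_new_tokens, block_size):
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]              # crop to the model's context window
        logits, _ = model(idx_cond)                  # (B, T, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)
    return idx

prompt = torch.randint(0, config.VOCAB_SIZE, (1, 8), device=device)
generated = greedy_generate(model, prompt, max_new_tokens=32, block_size=config.BLOCK_SIZE)
print(f"Generated sequence shape: {generated.shape}")
```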
Our experiments confirm that GAM trains faster per epoch and reaches lower validation perplexity than both baselines.
| Dataset | Model | Params | Avg. Time / Epoch (s) | Val. Loss | Val. PPL |
|---|---|---|---|---|---|
| WikiText-2 | Transformer | 24.2 M | 131.9 | 6.8233 | 918.99 |
| WikiText-2 | Mamba | 20.5 M | 127.1 | 6.9251 | 1017.54 |
| WikiText-2 | GAM (Ours) | 22.6 M | 117.2 | 6.7828 | 882.57 |
| TinyStories | Transformer | 24.2 M | 671.6 | 3.1591 | 23.55 |
| TinyStories | GAM (Ours) | 22.6 M | 601.4 | 3.1418 | 23.15 |
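The perplexity column is simply the exponential of the validation cross-entropy loss; a two-line check reproduces the table values.

```python
import math

# Perplexity = exp(cross-entropy loss), matching the table above.
print(math.exp(6.7828))  # ≈ 882.6 (GAM on WikiText-2)
print(math.exp(3.1418))  # ≈ 23.1  (GAM on TinyStories)
```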
GAM demonstrates clear linear (O(N)) scaling in time and memory, whereas the Transformer shows quadratic (O(N²)) growth and fails with out-of-memory (OOM) errors on longer sequences.
| Sequence Length | GAM Time (ms) | Transformer Time (ms) | GAM Memory (MB) | Transformer Memory (MB) |
|---|---|---|---|---|
| 256 | 8.97 | 8.90 | 179.42 | 216.03 |
| 512 | 13.09 | 23.86 | 325.48 | 552.98 |
| 1024 | 25.86 | 74.19 | 617.60 | 1964.79 |
| 2048 | 51.94 | 279.37 | 1201.85 | 7483.92 |
| 4096 | 105.03 | Failed (OOM) | 2370.35 | Failed (OOM) |
| 8192 | 217.30 | Failed (OOM) | 4707.35 | Failed (OOM) |
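As a quick sanity check on the trend, the step-to-step growth ratios computed from the table (sequence length doubles between rows) show GAM's time settling near 2× per doubling, consistent with linear scaling, while the Transformer's ratio climbs toward 4× per doubling, consistent with quadratic scaling.

```python
# Growth ratios between consecutive rows of the table (sequence length doubles each row).
gam_time = [8.97, 13.09, 25.86, 51.94, 105.03, 217.30]
transformer_time = [8.90, 23.86, 74.19, 279.37]  # OOM beyond 2048

def ratios(xs):
    return [round(b / a, 2) for a, b in zip(xs, xs[1:])]

print("GAM:        ", ratios(gam_time))          # [1.46, 1.98, 2.01, 2.02, 2.07] -> ~linear
print("Transformer:", ratios(transformer_time))  # [2.68, 3.11, 3.77] -> superlinear
```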
Our ablation study on WikiText-2 confirms that all components of the GAM block are essential for peak performance.
| Model Configuration | Gating? | Local? | Global? | Params | Val. PPL |
|---|---|---|---|---|---|
| GAM (Full) | ✅ | ✅ | ✅ | 22.6 M | 900.84 |
| GAM (Global Only) | ❌ | ❌ | ✅ | 19.4 M | 905.45 |
| GAM (Sum Fusion) | ❌ | ✅ | ✅ | 19.4 M | 942.59 |
| GAM (Local Only) | ❌ | ✅ | ❌ | 17.9 M | 944.70 |
Key takeaways:
- Gating is crucial for effective fusion: simple summation performs poorly.
- Global associative memory is the primary driver of performance: the `Global Only` model is very strong.
- Local context is complementary and necessary: the full model still outperforms `Global Only`, while `Local Only` is insufficient.
If you find this work useful in your research, please consider citing our paper:
```bibtex
@misc{acharya2025gatedassociativememoryparallel,
  title={Gated Associative Memory: A Parallel O(N) Architecture for Efficient Sequence Modeling},
  author={Rishiraj Acharya},
  year={2025},
  eprint={2509.00605},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2509.00605},
}
```