Bringing Memory Caching to State Space Models
mongoobi, Mar 2026
The paper "Memory Caching: RNNs with Growing Memory" introduces a mechanism for giving recurrent models explicit long-range memory by caching activations at segment boundaries and retrieving them via gated attention. The authors tested it on linear attention and Titans.
They never tested it on selective state space models.
This post covers the architecture design and implementation of MC-Mamba: memory caching applied to Mamba for autoregressive music generation. This is, as far as I can tell, the first attempt at this combination.
Why Memory Caching for Music?
Music has a specific structural problem: you need both local sequential coherence (the next beat follows from the last) AND long-range structural recall (the chorus returns after the bridge, the theme recurs in variations).
Mamba handles local dynamics well — it's basically a learned convolutional-recurrent model. But its fixed-size hidden state means long-range callbacks degrade. The baseline experiments showed Mamba-Attention hybrids matching transformers, but we want to push further.
Memory caching is a natural fit: cache "what happened at structurally important moments" and retrieve it when needed.
The Key Design Decision: Output Activation Caching
The original MC paper caches intermediate recurrent states. For Mamba, this doesn't work cleanly:
- The `mamba-ssm` CUDA kernel doesn't expose intermediate SSM hidden states $$\mathbf{h}_t \in \mathbb{R}^{d_{inner} \times d_{state}}$$ during the training-time parallel scan
- SSM states live in a different dimensional space ($$d_{inner} \times d_{state}$$) than model activations ($$d_{model}$$)
- Extracting mid-kernel states would break the hardware-aware parallelism that makes Mamba fast
Solution: cache block output activations instead.
At every segment boundary (positions $$S, 2S, 3S, \ldots$$), we store the Mamba block's output activation $$\mathbf{h}_i \in \mathbb{R}^{d_{model}}$$. These live in the same space as inputs, so retrieval is a simple gated residual addition.
This turns out to be arguably better than raw state caching:
1. Output activations are post-nonlinearity, post-gating — they're the "refined summary" of what the block computed
2. Cache entries are $$(B, d_{model})$$ instead of $$(B, d_{inner}, d_{state})$$ — much smaller
3. Full CUDA kernel parallelism is preserved during training
4. Gradients flow through cache entries naturally
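The caching step itself is tiny. Here is a minimal NumPy sketch of boundary extraction, with hypothetical names (`cache_boundary_activations`, a plain Python list as the cache) standing in for whatever the real implementation uses:

```python
import numpy as np

def cache_boundary_activations(x_out, segment_size, cache):
    """Append the block's output activation at each segment boundary
    (positions S-1, 2S-1, ...) to the cache.

    x_out: (B, L, d_model) block outputs; cache: list of (B, d_model) entries.
    """
    B, L, d_model = x_out.shape
    for boundary in range(segment_size - 1, L, segment_size):
        cache.append(x_out[:, boundary, :])  # one (B, d_model) entry per segment
    return cache

# Toy usage: one sequence of length 16 with segment size 4 -> 4 cached entries
x_out = np.random.randn(1, 16, 64)
cache = cache_boundary_activations(x_out, segment_size=4, cache=[])
```

Because each entry is just a slice of the block output, gradients reach the cache through ordinary autodiff in a real (PyTorch) implementation; nothing about the scan kernel has to change.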
GRM Gating (Equations 8-10 from the paper)
Each MC-enhanced block does:
- Standard forward: $$\mathbf{x}_{out} = \text{MambaBlock}(\mathbf{x}_{in})$$
- Extract boundaries: $$\mathbf{h}_i = \mathbf{x}_{out}[:, iS-1, :]$$ at segment boundaries
- Segment means: $$\mathbf{m}_i = \text{mean}(\mathbf{x}_{in}[\text{segment}_i])$$ as keys
- Query: $$\mathbf{u}_t = W_u \cdot \mathbf{x}_{in,t}$$
- Gate: $$\gamma_t^{(i)} = \text{softmax}\left(\frac{\langle \mathbf{u}_t, \mathbf{m}_i \rangle}{\sqrt{d}}\right)$$ over current + cached segments
- Output: $$\mathbf{y}_t = \gamma_t^{(\text{current})} \cdot \mathbf{x}_{out,t} + \sum_i \gamma_t^{(\text{cached}_i)} \cdot \mathbf{h}_i$$
The only new parameters per layer: $$W_u \in \mathbb{R}^{d_{model} \times d_{model}}$$.
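The gating path is small enough to sketch end to end. This is a NumPy toy (single sequence, no batch dimension) following the equations above; the function name `grm_gate` and the choice to use the current segment's own mean as its key are assumptions for illustration:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def grm_gate(x_in, x_out, cached_h, cached_m, W_u):
    """Gated retrieval over current + cached segments.

    x_in, x_out: (L, d) block input/output; cached_h: (N, d) boundary
    activations (values); cached_m: (N, d) segment-mean keys; W_u: (d, d).
    """
    L, d = x_in.shape
    u = x_in @ W_u.T                              # queries u_t = W_u x_in,t
    m_cur = x_in.mean(axis=0)                     # current segment's mean as its key (assumption)
    keys = np.vstack([m_cur[None, :], cached_m])  # (1+N, d): current first, then cache
    gamma = softmax(u @ keys.T / np.sqrt(d))      # (L, 1+N) gate weights per token
    # y_t = gamma_current * x_out,t + sum_i gamma_i * h_i
    y = gamma[:, :1] * x_out + gamma[:, 1:] @ cached_h
    return y, gamma
```

Note that the retrieval is a weighted residual in $$d_{model}$$ space: no projection of the cached values is needed, which is exactly why caching output activations (rather than SSM states) keeps the parameter cost at a single $$W_u$$ per layer.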
Three Architecture Variants
1. MC-Mamba (pure)
All 20 layers get memory caching. ~140M params (128M base + 11.8M MC overhead, ~9%).
2. MC-Linear-Attention
Replace Mamba blocks with linear attention (ELU+1 feature map, FLA kernel backend). All layers get MC. ~41M params at d_model=640, 12 layers.
3. Hybrid MC-LA 1:3
1 MC-enhanced linear attention layer per 3 standard linear attention layers. Only 25% of layers carry MC overhead. ~49M params, MC overhead drops to 2.5%.
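The 1:3 layout can be expressed as a simple index rule. A sketch, with the caveat that placing the MC layer last in each group of four is an assumption — the exact position within each group isn't specified above:

```python
def mc_layer_indices(n_layers, group=4):
    """Indices of MC-enhanced layers: one per group of `group` layers
    (1 MC : 3 standard). Placing it last in each group is an assumption."""
    return [i for i in range(n_layers) if i % group == group - 1]

# For the 12-layer hybrid: 3 of 12 layers carry MC overhead
indices = mc_layer_indices(12)  # -> [3, 7, 11]
```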
The Stability War
The first MC-Linear-Attention runs exploded:
```
FloatingPointError: WARNING unstable grad_norm at step 13: inf
```
Step 13. Thirteen steps in.
The debugging trail:
- Attempt 1: `retrieval_scale=1.0`, `lr=3e-4` → inf at step 13
- Attempt 2: `retrieval_scale=0.15`, `lr=1e-4` → inf at step 89. Progress, but not enough.
- Attempt 3: switched to `fail_on_large_grad_norm=0` (warn but don't die), `fail_on_nonfinite_grad_norm=1` (die on actual inf/nan), `grad_norm_warn_threshold=1e6`. This lets transient spikes pass while catching true divergence.
- Attempt 4 (current): `lr=1e-4`, `retrieval_scale=1.0`, larger dataset (25k samples vs 10k). Stable. The larger dataset smoothed the gradients enough.
Lesson: MC introduces a retrieval residual path with softmax gating. Early in training, when the cache is near-empty and representations are random, the gate weights are poorly conditioned. More data → more diverse gradient signal → smoother optimization landscape.
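The "warn but don't die" guard from attempt 3 boils down to a few lines. A sketch of the check, mirroring those flag names (the function itself and its return convention are hypothetical):

```python
import math

def check_grad_norm(grad_norm, step,
                    fail_on_large_grad_norm=False,      # 0: warn but don't die
                    fail_on_nonfinite_grad_norm=True,   # 1: die on actual inf/nan
                    grad_norm_warn_threshold=1e6):
    """Let transient gradient spikes pass while catching true divergence."""
    if not math.isfinite(grad_norm):
        if fail_on_nonfinite_grad_norm:
            raise FloatingPointError(
                f"unstable grad_norm at step {step}: {grad_norm}")
        return "warn"
    if grad_norm > grad_norm_warn_threshold:
        if fail_on_large_grad_norm:
            raise FloatingPointError(
                f"large grad_norm at step {step}: {grad_norm}")
        return "warn"
    return "ok"
```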
Architecture Summary
| Variant | Params | MC Layers | MC Overhead | Throughput |
|---|---|---|---|---|
| MC-Linear-Attention | 41M | 12/12 | 12.0% | ~120k tok/s |
| Hybrid MC-LA 1:3 | 49M | 3/12 | 2.5% | ~225k tok/s |
What We're Tracking
GRM Entropy — Shannon entropy of the gating distribution $$\gamma$$. High entropy = uniform gating (not selective). Low entropy = the model is learning to attend to specific cached segments. We want this to decrease during training.
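For concreteness, the metric is just row-wise Shannon entropy of $$\gamma$$, averaged over tokens. A minimal sketch (the `eps` guard against $$\log 0$$ is an implementation assumption):

```python
import numpy as np

def grm_entropy(gamma, eps=1e-12):
    """Mean Shannon entropy (in nats) of the per-token gating distribution.

    gamma: (L, K) array, each row a distribution over current + cached segments.
    Uniform rows give log(K) (not selective); near-one-hot rows give ~0.
    """
    return float(-(gamma * np.log(gamma + eps)).sum(axis=-1).mean())
```

Tracking this per layer shows whether the gates are collapsing toward specific cached segments (entropy falling toward 0) or staying near the uniform ceiling of $$\log K$$.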
Cache Size — Number of cached segment entries. Starts at 8 (sequence length / segment size) and can grow with longer sequences.