Bringing Memory Caching to State Space Models

mongoobi, Mar 2026


The paper "Memory Caching: RNNs with Growing Memory" introduces a mechanism for giving recurrent models explicit long-range memory by caching activations at segment boundaries and retrieving them via gated attention. The authors tested it on linear attention and on Titans.

They never tested it on selective state space models.

This post covers the architecture design and implementation of MC-Mamba: memory caching applied to Mamba for autoregressive music generation. This is, as far as I can tell, the first attempt at this combination.


Why Memory Caching for Music?

Music has a specific structural problem: you need both local sequential coherence (the next beat follows from the last) AND long-range structural recall (the chorus returns after the bridge, the theme recurs in variations).

Mamba handles local dynamics well — it's basically a learned convolutional-recurrent model. But its fixed-size hidden state means long-range callbacks degrade. The baseline experiments showed Mamba-Attention hybrids matching transformers, but we want to push further.

Memory caching is a natural fit: cache "what happened at structurally important moments" and retrieve it when needed.


The Key Design Decision: Output Activation Caching

The original MC paper caches intermediate recurrent states. For Mamba, this doesn't work cleanly:

  • The mamba-ssm CUDA kernel doesn't expose intermediate SSM hidden states $$(\mathbf{h}_t \in \mathbb{R}^{d_{inner} \times d_{state}})$$ during the training-time parallel scan
  • SSM states live in a different dimensional space $$(d_{inner} \times d_{state})$$ than model activations $$(d_{model})$$
  • Extracting mid-kernel states would break the hardware-aware parallelism that makes Mamba fast

Solution: cache block output activations instead.

At every segment boundary (positions $$S, 2S, 3S, \ldots$$), we store the Mamba block's output activation $$\mathbf{h}_i \in \mathbb{R}^{d_{model}}$$. These live in the same space as inputs, so retrieval is a simple gated residual addition.

This turns out to be arguably better than raw state caching:

  1. Output activations are post-nonlinearity, post-gating — they're the "refined summary" of what the block computed
  2. Cache entries are $$(B, d_{model})$$ instead of $$(B, d_{inner}, d_{state})$$ — much smaller
  3. Full CUDA kernel parallelism preserved during training
  4. Gradients flow through cache entries naturally
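As a minimal sketch (the function name and list-of-vectors layout are mine, not the actual implementation), boundary caching is just an indexing step over the block's output:

```python
def cache_boundaries(x_out, segment_size):
    """Collect the block's output activation at each segment boundary
    (the last timestep of every segment). x_out is a list of per-timestep
    activation vectors; entries stay in d_model space, so retrieval can
    later add them back as a gated residual."""
    return [x_out[i] for i in range(segment_size - 1, len(x_out), segment_size)]
```

Because nothing is pulled out of the scan itself, the CUDA kernel runs exactly as in vanilla Mamba.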


GRM Gating (Equations 8-10 from the paper)

Each MC-enhanced block does:

  1. Standard forward: $$\mathbf{x}_{out} = \text{MambaBlock}(\mathbf{x}_{in})$$
  2. Extract boundaries: $$\mathbf{h}_i = \mathbf{x}_{out}[:, iS-1, :]$$ at segment boundaries
  3. Segment means: $$\mathbf{m}_i = \text{mean}(\mathbf{x}_{in}[\text{segment}_i])$$ as keys
  4. Query: $$\mathbf{u}_t = W_u \cdot \mathbf{x}_{in,t}$$
  5. Gate: $$\gamma_t^{(i)} = \text{softmax}\left(\frac{\langle \mathbf{u}_t, \mathbf{m}_i \rangle}{\sqrt{d}}\right)$$ over current + cached segments
  6. Output: $$\mathbf{y}_t = \gamma_t^{(\text{current})} \cdot \mathbf{x}_{out,t} + \sum_i \gamma_t^{(\text{cached}_i)} \cdot \mathbf{h}_i$$

The only new parameters per layer: $$W_u \in \mathbb{R}^{d_{model} \times d_{model}}$$.
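Steps 4-6 can be sketched as a toy, unbatched Python function (grm_gate and its argument layout are my naming, not the paper's; the first entry of keys is the current segment's mean, the rest align with the cached h_i):

```python
import math

def grm_gate(u_t, keys, values, x_out_t, d):
    """Gated retrieval over current + cached segments.
    u_t: query vector W_u @ x_in_t; keys: segment-mean vectors, current
    segment first; values: cached boundary activations h_i, aligned with
    keys[1:]; x_out_t: current block output; d: key dimension."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Scaled dot-product scores, then a numerically stable softmax.
    scores = [dot(u_t, m) / math.sqrt(d) for m in keys]
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    z = sum(exps)
    gamma = [e / z for e in exps]

    # Convex combination: current output plus gated cache entries.
    out = [gamma[0] * v for v in x_out_t]
    for g, h in zip(gamma[1:], values):
        out = [o + g * hv for o, hv in zip(out, h)]
    return out, gamma
```

In the real model this runs batched over all timesteps; the sketch just makes the softmax-over-segments structure explicit.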


Three Architecture Variants

1. MC-Mamba (pure)

All 20 layers get memory caching. ~140M params (128M base + 11.8M MC overhead, ~9%).
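The ~9% figure checks out arithmetically if d_model = 768 (an assumption on my part — the post only gives the totals, but 11.8M spread over 20 $$W_u$$ matrices implies 768² parameters each):

```python
# Assumed d_model = 768; only the 128M base and 11.8M overhead are stated.
d_model, n_layers, base_params = 768, 20, 128_000_000
mc_params = n_layers * d_model * d_model  # one d_model x d_model W_u per layer
# mc_params == 11796480 (~11.8M); mc_params / base_params ~= 0.092, i.e. ~9%
```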

2. MC-Linear-Attention

Replace Mamba blocks with linear attention (ELU+1 feature map, FLA kernel backend). All layers get MC. ~41M params at d_model=640, 12 layers.

3. Hybrid MC-LA 1:3

1 MC-enhanced linear attention layer per 3 standard linear attention layers. Only 25% of layers carry MC overhead. ~49M params, MC overhead drops to 2.5%.
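One way to realize the 1:3 interleave (the exact placement within each group of four is my guess; the post only states the ratio):

```python
def hybrid_layout(n_layers, group=4):
    """Mark every 4th layer as MC-enhanced and the rest as plain linear
    attention, giving 1 MC : 3 plain -- 25% of layers carry MC overhead."""
    return ["mc" if i % group == group - 1 else "la" for i in range(n_layers)]

layout = hybrid_layout(12)  # 3 MC layers out of 12
```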


The Stability War

The first MC-Linear-Attention runs exploded:

FloatingPointError: WARNING unstable grad_norm at step 13: inf

Step 13. Thirteen steps in.

The debugging trail:

  • Attempt 1: retrieval_scale=1.0, lr=3e-4 → inf at step 13
  • Attempt 2: retrieval_scale=0.15, lr=1e-4 → inf at step 89. Progress, but not enough.
  • Attempt 3: Switched to fail_on_large_grad_norm=0 (warn but don't die), fail_on_nonfinite_grad_norm=1 (die on actual inf/nan), grad_norm_warn_threshold=1e6. This lets transient spikes pass while catching true divergence.
  • Attempt 4 (current): lr=1e-4, retrieval_scale=1.0, larger dataset (25k samples vs 10k). Stable. The larger dataset smoothed the gradients enough.

Lesson: MC introduces a retrieval residual path with softmax gating. Early in training, when the cache is near-empty and representations are random, the gate weights are poorly conditioned. More data → more diverse gradient signal → smoother optimization landscape.
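The attempt-3 policy reduces to a small guard (a sketch in plain Python; the fail_on_* names above are the real trainer flags, check_grad_norm is mine):

```python
def check_grad_norm(grad_norm, step, warn_threshold=1e6):
    """Warn on large-but-finite spikes, raise only on true divergence
    (inf/nan) -- i.e. fail_on_large_grad_norm=0 plus
    fail_on_nonfinite_grad_norm=1 with grad_norm_warn_threshold."""
    if grad_norm != grad_norm or grad_norm in (float("inf"), float("-inf")):
        raise FloatingPointError(f"nonfinite grad_norm at step {step}: {grad_norm}")
    if grad_norm > warn_threshold:
        print(f"WARNING unstable grad_norm at step {step}: {grad_norm}")
```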


Architecture Summary

| Variant             | Params | MC Layers | MC Overhead | Throughput  |
|---------------------|--------|-----------|-------------|-------------|
| MC-Linear-Attention | 41M    | 12/12     | 12.0%       | ~120k tok/s |
| Hybrid MC-LA 1:3    | 49M    | 3/12      | 2.5%        | ~225k tok/s |

What We're Tracking

GRM Entropy — Shannon entropy of the gating distribution $$\gamma$$. High entropy = uniform gating (not selective). Low entropy = the model is learning to attend to specific cached segments. We want this to decrease during training.
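The entropy metric itself is one line (natural-log entropy here; the post doesn't say whether it logs nats or bits):

```python
import math

def grm_entropy(gamma):
    """Shannon entropy of the gating distribution. Uniform gating over k
    segments gives log(k) -- not selective; a one-hot gate gives 0."""
    return -sum(g * math.log(g) for g in gamma if g > 0)
```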

Cache Size — Number of cached segment entries. Starts at 8 (sequence length / segment size) and can grow with longer sequences.

[Figure: current training results, MC-LA vs Hybrid MC-LA 1:3]