
Chunkwise linear attention and DeltaNet (Part 3)

Linear Attention Part 3 — Chunkwise kernels, the WY representation, and the MQAR benchmark

1. Recap

In Part 1 we derived from softmax attention the recurrence relation $S_t = S_{t-1} A_t + b_t k_t^\top$ using a finite-dimensional kernel and a state $S$ with capacity $n \lesssim d_k$, and further motivated and derived the delta rule and gating. In Part 2 we expanded this concept into four complementary directions (memory architecture, objective, optimizer, retention) and showed how every finite-sized memory model fits into the picture. In this part we cover the engineering side of training these models efficiently on GPUs, and then benchmark them to test our theoretical picture empirically.

2. Chunkwise

Looking back at where we came from, we motivated linear attention for the regime $T \gg d$. For example, $T$ reaches ~1M for the newest models (Claude Opus 4.7).

The recurrence

$$S_t = S_{t-1} + v_t k_t^\top, \qquad o_t = S_t q_t$$

is nice for inference, as we pay a constant $O(d^2)$ per step instead of softmax’s per-step cost that grows with $T$ ($O(T^2)$ over the sequence), but training is now the problem. There we have the full sequence, yet the recurrence still forces us to compute every $S_t$ in order. On modern accelerators built around big matmuls, this wastes a lot of compute and shows up as low MFU. For a single sequence it means thousands of sequential kernel launches, each a tiny matmul where launch overhead dominates the actual work. What we want is the opposite extreme: one big matmul over the whole sequence at once.

The full parallel form for linear attention is possible and can fully utilize the GPU when $T$ is manageable. Pretraining at 4–8K is fine, and finetuning at 32K or 128K still works, but ideally we’d train on 256K–1M sequences to stay close to inference length. At that scale the $T \times T$ score matrix doesn’t fit in memory, and we need a way to chunk along $T$.
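As a concrete reference point, here is a minimal pure-Python sketch of the two extremes (the names `recurrent_la` and `parallel_la` are hypothetical, and plain lists stand in for tensors): the token-by-token recurrence and the masked parallel form. Both compute the same outputs; the parallel one materializes the full $T \times T$ score matrix, which is what stops fitting at long $T$.

```python
import random


def recurrent_la(Q, K, V):
    """Token-by-token linear attention: S_t = S_{t-1} + v_t k_t^T, o_t = S_t q_t.

    Q and K are T x d_k, V is T x d_v (lists of rows); the state S is d_v x d_k.
    """
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_k for _ in range(d_v)]
    out = []
    for q, k, v in zip(Q, K, V):
        for a in range(d_v):  # rank-1 write v_t k_t^T
            for b in range(d_k):
                S[a][b] += v[a] * k[b]
        out.append([sum(S[a][b] * q[b] for b in range(d_k)) for a in range(d_v)])
    return out


def parallel_la(Q, K, V):
    """Parallel form: O = ((Q K^T) ⊙ M) V with a causal mask M (diagonal included)."""
    T, d_v = len(Q), len(V[0])
    # This T x T score matrix is the object that runs out of memory at long T.
    scores = [[sum(x * y for x, y in zip(Q[t], K[s])) if s <= t else 0.0
               for s in range(T)] for t in range(T)]
    return [[sum(scores[t][s] * V[s][a] for s in range(T)) for a in range(d_v)]
            for t in range(T)]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
Q, K, V = random_matrix(8, 4, rng), random_matrix(8, 4, rng), random_matrix(8, 3, rng)
O_rec, O_par = recurrent_la(Q, K, V), parallel_la(Q, K, V)
# Largest disagreement between the two forms: rounding noise only.
print(max(abs(x - y) for r, p in zip(O_rec, O_par) for x, y in zip(r, p)))
```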

Pure parallel (linear attention) vs pure recurrent (DeltaNet): time and peak memory across T from 4K to 131K at production head_dim=128 on an A100. Parallel OOMs at T=131K; recurrent is slow at every scale.

The bigger problem: the full parallel form only exists for linear attention, not for DeltaNet.

Linear attention

Let’s start with linear attention. Split the sequence of length $T$ into chunks of size $C$, with $Q_i, K_i \in \mathbb{R}^{C \times d_k}$ and $V_i \in \mathbb{R}^{C \times d_v}$ the query/key/value matrices for chunk $i$ (covering positions $iC+1$ to $(i+1)C$, counting positions from 1). The state $S_i \in \mathbb{R}^{d_v \times d_k}$ is the starting state of chunk $i$: it contains every write from earlier chunks and none from chunk $i$.

Leaning on softmax attention and Part 1, we can compute the output inside a chunk in parallel form:

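$$O_i^{\text{intra}} = \bigl((Q_i K_i^\top) \odot M\bigr)\, V_i,$$

where $M$ is the $C \times C$ lower-triangular causal mask with the diagonal included (position $t$ reads $S_t$, which already contains its own write). But this only sees writes inside the current chunk. We add the readout from all previous chunks by reading from $S_i$: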

$$O_i^{\text{prev}} = Q_i S_i^\top.$$

What’s left is the next state, $S_{i+1}$, which is just chunk $i$’s contribution accumulated into $S_i$:

$$S_{i+1} = S_i + V_i^\top K_i.$$

Final chunk output: $O_i = O_i^{\text{prev}} + O_i^{\text{intra}}$. The whole algorithm is sequential over $T/C$ chunks (the $S$ hand-off), but the work inside each chunk is mostly matmuls, which makes it GPU-friendly.

Total cost: $T/C$ chunks, each costing $O(Cd^2)$ for the state read $Q_i S_i^\top$ and the state write $V_i^\top K_i$, plus $O(C^2 d)$ for the intra-chunk block, giving $O(Td^2 + TCd)$ overall. This is more FLOPs than the recurrent form’s $O(Td^2)$, but we amortize wall-clock time through big matmuls inside each chunk and stay viable at sequence lengths where the parallel form’s score matrix wouldn’t fit.
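A minimal sketch of this chunkwise scheme, assuming plain Python lists for matrices and a hypothetical `chunkwise_la` helper (production kernels fuse these steps on GPU). Each loop iteration is one chunk: read the carried state, add the masked intra-chunk block, then fold the chunk’s writes into $S$.

```python
import random


def matmul(A, B):
    """Dense product of two list-of-rows matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def add(A, B):
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(A, B)]


def chunkwise_la(Q, K, V, C):
    """Chunkwise linear attention with chunk size C (the last chunk may be shorter)."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_k for _ in range(d_v)]  # state at the start of the current chunk
    out = []
    for start in range(0, len(Q), C):
        Qi, Ki, Vi = Q[start:start + C], K[start:start + C], V[start:start + C]
        o_prev = matmul(Qi, transpose(S))                    # Q_i S_i^T
        scores = matmul(Qi, transpose(Ki))                   # Q_i K_i^T, C x C
        masked = [[x if s <= t else 0.0 for s, x in enumerate(row)]
                  for t, row in enumerate(scores)]            # ⊙ M, diagonal kept
        out.extend(add(o_prev, matmul(masked, Vi)))          # O_prev + O_intra
        S = add(S, matmul(transpose(Vi), Ki))                # S_{i+1} = S_i + V_i^T K_i
    return out


rng = random.Random(1)
T, d_k, d_v = 10, 4, 3
Q = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(T)]
K = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(T)]
V = [[rng.gauss(0, 1) for _ in range(d_v)] for _ in range(T)]
O = chunkwise_la(Q, K, V, C=4)  # chunks of 4, 4 and 2 positions
```

Any chunk size, including $C = 1$ (pure recurrence) and $C = T$ (pure parallel), gives the same outputs up to rounding; $C$ only trades memory for parallelism.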

Same comparison with chunkwise C=64 and C=512 added. Chunkwise stays viable where parallel OOMs at T≥131K, and is close to parallel speed at smaller T.

DeltaNet

Recall DeltaNet’s recurrence:

$$S_t = S_{t-1}(I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top.$$

Write $A_t \equiv I - \beta_t k_t k_t^\top$ for the rank-1 transition. Unrolling within chunk $i$, the state at position $iC + t$ is

$$S_{iC+t} = S_{iC}\, A_{iC+1} A_{iC+2} \cdots A_{iC+t} + \sum_{s=1}^{t} \beta_{iC+s}\, v_{iC+s} k_{iC+s}^\top\, A_{iC+s+1} \cdots A_{iC+t},$$

where the empty tail product at $s = t$ is $I$.

Two things change. The cross-chunk hand-off is no longer additive: the carry $S_{iC} = S_i$ has to pass through the product of $C$ identity-minus-rank-1 factors $A_{iC+1} \cdots A_{iC+C}$. And each write at position $iC + s$ inside the chunk gets multiplied on the right by the tail product of the $A$’s after it.
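To see the unrolled formula in action, here is a small sketch (hypothetical helper names, plain Python lists) that runs DeltaNet’s recurrence token by token and, separately, evaluates the unrolled expression by materializing every tail product as a dense $d_k \times d_k$ matrix, which is exactly the expensive route discussed next. Keys are L2-normalized here as an assumption to keep the numbers tame.

```python
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def add(A, B):
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(A, B)]


def eye(n):
    return [[float(i == j) for j in range(n)] for i in range(n)]


def transition(k, beta):
    """A_t = I - beta_t k_t k_t^T as an explicit d_k x d_k matrix."""
    d = len(k)
    return [[float(i == j) - beta * k[i] * k[j] for j in range(d)] for i in range(d)]


def write(v, k, beta):
    """The rank-1 write beta_t v_t k_t^T, a d_v x d_k matrix."""
    return [[beta * a * b for b in k] for a in v]


def deltanet_recurrent(S0, K, V, beta):
    """S_t = S_{t-1} A_t + beta_t v_t k_t^T, one token at a time."""
    S = S0
    for k, v, b in zip(K, V, beta):
        S = add(matmul(S, transition(k, b)), write(v, k, b))
    return S


def deltanet_unrolled(S0, K, V, beta):
    """S_0 A_1...A_C + sum_s beta_s v_s k_s^T A_{s+1}...A_C with every tail materialized."""
    d_k = len(K[0])

    def product(pairs):
        P = eye(d_k)                        # the empty product is the identity
        for k, b in pairs:
            P = matmul(P, transition(k, b))
        return P

    pairs = list(zip(K, beta))
    S = matmul(S0, product(pairs))          # carry through the whole chunk
    for s in range(len(K)):                 # each write through its own tail product
        S = add(S, matmul(write(V[s], K[s], beta[s]), product(pairs[s + 1:])))
    return S


def unit(v):
    norm = sum(x * x for x in v) ** 0.5
    return [x / norm for x in v]


rng = random.Random(2)
C, d_k, d_v = 5, 3, 2
K = [unit([rng.gauss(0, 1) for _ in range(d_k)]) for _ in range(C)]
V = [[rng.gauss(0, 1) for _ in range(d_v)] for _ in range(C)]
beta = [rng.random() for _ in range(C)]
S0 = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(d_v)]
S_rec = deltanet_recurrent(S0, K, V, beta)
S_unr = deltanet_unrolled(S0, K, V, beta)
```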

If we materialize each tail product as a $d \times d$ matrix, the chunkwise cost goes to $O(TCd^2)$, a factor of $C$ worse than just running DeltaNet recurrently. Speedup gone. We need a smarter way to represent the product of $C$ identity-minus-rank-1 factors.

3. WY representation

Each $A_t$ is identity minus rank-1, so a product of $C$ such factors is identity minus rank-$\le C$:

$$A_1 A_2 \cdots A_C = I - W^\top Y, \qquad W, Y \in \mathbb{R}^{C \times d_k}.$$

This is the WY representation (Bischof & Van Loan, 1987), originally developed for products of Householder reflectors in QR. The win: we never have to materialize the $C$ different tail products $A_{s+1} \cdots A_t$ explicitly, and applying $I - W^\top Y$ to a vector takes two matrix-vector products of cost $Cd_k$ each, without ever forming a $d_k \times d_k$ matrix.
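A sketch of that cost claim, assuming $W$ and $Y$ are stored as $C \times d_k$ lists of rows (`apply_wy` is a hypothetical name): compute $c = Yx$ first, then subtract $W^\top c$. Both steps touch each of the $C d_k$ entries once.

```python
def apply_wy(W, Y, x):
    """Apply (I - W^T Y) to a length-d_k vector x in O(C d_k).

    W and Y are C x d_k lists of rows; no d_k x d_k matrix is ever formed.
    """
    c = [sum(y * xi for y, xi in zip(y_row, x)) for y_row in Y]   # c = Y x (C entries)
    # x - W^T c: entry j subtracts sum_r c_r * W[r][j]
    return [xj - sum(c[r] * W[r][j] for r in range(len(W))) for j, xj in enumerate(x)]


W = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
Y = [[0.0, 1.0, 1.0], [2.0, 0.0, -1.0]]
x = [1.0, 2.0, 3.0]
print(apply_wy(W, Y, x))  # [-3.5, 1.0, -7.0]
```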

The closed form for $W$. We can derive $W$ and $Y$ inductively, appending one $A_t$ at a time to a known WY representation of the prefix.

Base case ($C = 1$). Match $A_1 = I - \beta_1 k_1 k_1^\top$ against $I - W^\top Y$. With $W, Y \in \mathbb{R}^{1 \times d_k}$, the natural choice is $w_1 = \beta_1 k_1$ and $y_1 = k_1$.

Inductive step. Assume $A_1 \cdots A_{t-1} = I - W_{<t}^\top Y_{<t}$, with the rows of $Y$ being the keys. Multiply by $A_t$:

$$A_1 \cdots A_t = (I - W_{<t}^\top Y_{<t})(I - \beta_t k_t k_t^\top) = I - W_{<t}^\top Y_{<t} - \beta_t \bigl((A_1 \cdots A_{t-1}) k_t\bigr) k_t^\top.$$

This matches $I - W_{\le t}^\top Y_{\le t}$ exactly when we set $y_t = k_t$ and

$$w_t = \beta_t \cdot (A_1 \cdots A_{t-1})\, k_t.$$

In words, $w_t$ is $\beta_t k_t$ with all prior $A$’s already applied. So the rows of $Y$ are just the chunk’s keys, $Y = K$, and $W$ is built from this closed form.

Unfolding $w_t$ into a recursion. We don’t want to actually compute $A_1 \cdots A_{t-1}$ as a full matrix. But it’s already in WY form by induction, so we can apply it cheaply:

$$(A_1 \cdots A_{t-1})\, k_t = (I - W_{<t}^\top Y_{<t})\, k_t = k_t - W_{<t}^\top (Y_{<t}\, k_t).$$

Each entry of the vector $Y_{<t} k_t$ is just $k_s^\top k_t$ (since the rows of $Y$ are the keys), so

$$W_{<t}^\top (Y_{<t}\, k_t) = \sum_{s<t} (k_s^\top k_t)\, w_s.$$

Plugging back:

$$w_t = \beta_t\Bigl(k_t - \sum_{s < t}(k_s^\top k_t)\, w_s\Bigr).$$

The $k_s^\top k_t$ entries form a $C \times C$ matrix, the chunk’s key Gram matrix $K_r K_r^\top$ (from here on, $r$ indexes the chunk). It is a different matrix from the intra-chunk attention scores $Q_r K_r^\top$, but has the same shape and the same $O(C^2 d_k)$ cost to compute.

Stack the recursion as a triangular system: $T_r \in \mathbb{R}^{C \times C}$ strictly lower triangular with $(T_r)_{ts} = \beta_t (k_s^\top k_t)$ for $s < t$. Then

$$(I + T_r)\, W_r = \mathrm{diag}(\beta)\, K_r.$$

This is unit-triangular forward substitution on a $C \times C$ system, applied across the $d_k$ output columns, at cost $O(C^2 d_k)$: the same order as the intra-chunk attention block, so WY is essentially free.
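A minimal sketch of the construction, assuming L2-normalized keys and plain Python lists (`build_w`, `dense_prefix_product` and `wy_product` are hypothetical names). It builds $W$ row by row with the forward-substitution recursion above and, for checking only, compares $I - W^\top K$ against the explicitly multiplied product $A_1 \cdots A_C$.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def build_w(K, beta):
    """Rows of W from (I + T) W = diag(beta) K by forward substitution.

    T is strictly lower triangular with T[t][s] = beta_t * (k_s . k_t), so row t is
    w_t = beta_t * (k_t - sum_{s<t} (k_s . k_t) w_s).
    """
    W = []
    for t, (k_t, b_t) in enumerate(zip(K, beta)):
        w = [b_t * x for x in k_t]                        # row t of diag(beta) K
        for s in range(t):
            coeff = b_t * dot(K[s], k_t)                  # T[t][s]
            w = [wi - coeff * ws for wi, ws in zip(w, W[s])]
        W.append(w)
    return W


def dense_prefix_product(K, beta):
    """A_1 A_2 ... A_C as an explicit matrix, for checking only."""
    d = len(K[0])
    P = [[float(i == j) for j in range(d)] for i in range(d)]
    for k, b in zip(K, beta):
        Pk = [dot(row, k) for row in P]                   # P k
        # P (I - b k k^T) = P - b (P k) k^T
        P = [[P[i][j] - b * Pk[i] * k[j] for j in range(d)] for i in range(d)]
    return P


def wy_product(W, K):
    """I - W^T K, the WY form of the same product (dense, for checking only)."""
    d = len(K[0])
    return [[float(i == j) - sum(w[i] * k[j] for w, k in zip(W, K)) for j in range(d)]
            for i in range(d)]


rng = random.Random(3)
C, d_k = 6, 4
K = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(C)]
K = [[x / dot(k, k) ** 0.5 for x in k] for k in K]       # L2-normalized keys
beta = [rng.random() for _ in range(C)]
W = build_w(K, beta)
```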

This gives us the chunk’s prefix product in closed form: $A_1 \cdots A_C = I - W_r^\top K_r$. Applying it to anything takes two matrix products against $W_r$ and $K_r$, never the $d_k \times d_k$ matrix itself.

The twin matrix $U_r$. That handles the inter-chunk hand-off through $W$. The chunk also contributes its own writes, $\sum_j \beta_j v_j k_j^\top A_{j+1} \cdots A_C$, and each $j$ has a different tail product. The same trick applies: define $U_r \in \mathbb{R}^{C \times d_v}$ with rows

$$u_t = \beta_t\Bigl(v_t - \sum_{s<t}(k_s^\top k_t)\, u_s\Bigr),$$

giving the chunk’s write contribution as $U_r^\top K_r$ (by the same induction as for $W$, starting from a zero state instead of $I$). Same coefficient matrix as $W_r$, different right-hand side:

$$(I + T_r)\, U_r = \mathrm{diag}(\beta)\, V_r.$$

Two triangular solves with a shared $T_r$. Each is forward substitution on a $C \times C$ system: row 1 of the output equals row 1 of the right-hand side, and row $t$ is row $t$ of the right-hand side minus $\sum_{s<t} (T_r)_{ts} \cdot (\text{output row } s)$. This is sequential in $t$ but parallel across the output columns ($d_k$ for $W_r$, $d_v$ for $U_r$). For typical $C \le 256$ it is a tiny system that batched triangular-solve (trsm) kernels handle cheaply on GPU.
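The two solves can share one forward-substitution routine. In this sketch (hypothetical names, plain Python lists, L2-normalized keys assumed), `forward_substitute` solves $(I + T_r)X = \text{rhs}$ row by row and the same $T_r$ is reused for $W_r$ and $U_r$; the final line forms $U_r^\top K_r$, the chunk’s own writes starting from a zero state.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def coefficient_matrix(K, beta):
    """T with T[t][s] = beta_t (k_s . k_t) for s < t, zero on and above the diagonal."""
    C = len(K)
    return [[beta[t] * dot(K[s], K[t]) if s < t else 0.0 for s in range(C)]
            for t in range(C)]


def forward_substitute(T, rhs):
    """Solve (I + T) X = rhs for strictly lower-triangular T, one row at a time."""
    X = []
    for t, row in enumerate(rhs):
        x = list(row)                                     # row t of the right-hand side ...
        for s in range(t):
            x = [xi - T[t][s] * xs for xi, xs in zip(x, X[s])]   # ... minus T[t][s] * X[s]
        X.append(x)
    return X


def wy_factors(K, V, beta):
    """W (C x d_k) and U (C x d_v): two solves sharing one coefficient matrix."""
    T = coefficient_matrix(K, beta)
    W = forward_substitute(T, [[b * x for x in k] for k, b in zip(K, beta)])
    U = forward_substitute(T, [[b * x for x in v] for v, b in zip(V, beta)])
    return W, U


rng = random.Random(4)
C, d_k, d_v = 5, 4, 3
K = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(C)]
K = [[x / dot(k, k) ** 0.5 for x in k] for k in K]
V = [[rng.gauss(0, 1) for _ in range(d_v)] for _ in range(C)]
beta = [rng.random() for _ in range(C)]
W, U = wy_factors(K, V, beta)
# The chunk's own writes, starting from a zero state, as U^T K (d_v x d_k).
writes = [[sum(U[r][a] * K[r][j] for r in range(C)) for j in range(d_k)] for a in range(d_v)]
```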

Final chunk update.

$$S^{(r+1)} = S^{(r)} - \bigl(S^{(r)} W_r^\top\bigr) K_r + U_r^\top K_r.$$

This is $S^{(r)}(I - W_r^\top K_r) + U_r^\top K_r$ multiplied out: three matmuls ($S^{(r)} W_r^\top$, its product with $K_r$, and $U_r^\top K_r$), plus the two triangular solves for $W_r$ and $U_r$. Same asymptotic cost as chunkwise linear attention, with only a constant-factor overhead for the identity-minus-rank-1 transitions.

The chunk’s outputs $O_r$ assemble in two pieces, as for linear attention. The intra-chunk readout swaps $V_r$ for $U_r$ in §2’s formula:

$$O_r^{\text{intra}} = \bigl((Q_r K_r^\top) \odot M\bigr)\, U_r.$$

The readout from previous chunks uses WY in the same way, since each query needs the partial prefix product $A_1 \cdots A_t = I - W_{\le t}^\top K_{\le t}$ applied before reading from $S^{(r)}$:

$$O_r^{\text{prev}} = \Bigl(Q_r - \bigl((Q_r K_r^\top) \odot M\bigr) W_r\Bigr)\, S^{(r)\top}.$$

Final: $O_r = O_r^{\text{prev}} + O_r^{\text{intra}}$.
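Putting the pieces together, here is a minimal end-to-end sketch of a chunkwise DeltaNet forward pass (hypothetical `chunkwise_deltanet`, plain Python lists, L2-normalized keys assumed). Per chunk it solves for $W_r$ and $U_r$, reads the carried state through the partial prefix products as $(Q_r - ((Q_r K_r^\top) \odot M) W_r)\, S^{(r)\top}$, adds the intra-chunk term with $U_r$, and applies the state update. It is meant to check the algebra, not to be fast.

```python
import random


def matmul(A, B):
    """Dense product of list-of-rows matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def add(A, B, sign=1.0):
    """A + sign * B, elementwise."""
    return [[a + sign * b for a, b in zip(r1, r2)] for r1, r2 in zip(A, B)]


def forward_substitute(T, rhs):
    """Solve (I + T) X = rhs for strictly lower-triangular T, row by row."""
    X = []
    for t, row in enumerate(rhs):
        x = list(row)
        for s in range(t):
            x = [xi - T[t][s] * xs for xi, xs in zip(x, X[s])]
        X.append(x)
    return X


def chunkwise_deltanet(Q, K, V, beta, C):
    """DeltaNet outputs o_t = S_t q_t, computed chunk by chunk with the WY form."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_k for _ in range(d_v)]                     # S^(r), d_v x d_k
    out = []
    for start in range(0, len(Q), C):
        Qr, Kr, Vr, br = (X[start:start + C] for X in (Q, K, V, beta))
        n = len(Kr)                                           # last chunk may be shorter
        gram = matmul(Kr, transpose(Kr))                      # gram[t][s] = k_t . k_s
        T = [[br[t] * gram[t][s] if s < t else 0.0 for s in range(n)] for t in range(n)]
        W = forward_substitute(T, [[b * x for x in k] for k, b in zip(Kr, br)])
        U = forward_substitute(T, [[b * x for x in v] for v, b in zip(Vr, br)])
        scores = matmul(Qr, transpose(Kr))
        P = [[x if s <= t else 0.0 for s, x in enumerate(row)]   # (Q K^T) ⊙ M
             for t, row in enumerate(scores)]
        o_prev = matmul(add(Qr, matmul(P, W), sign=-1.0), transpose(S))  # (Q - P W) S^T
        o_intra = matmul(P, U)                                           # P U
        out.extend(add(o_prev, o_intra))
        # S^(r+1) = S^(r) - (S^(r) W^T) K + U^T K
        S = add(add(S, matmul(matmul(S, transpose(W)), Kr), sign=-1.0),
                matmul(transpose(U), Kr))
    return out


def unit(v):
    norm = sum(x * x for x in v) ** 0.5
    return [x / norm for x in v]


rng = random.Random(5)
T_len, d_k, d_v = 11, 4, 3
Q = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(T_len)]
K = [unit([rng.gauss(0, 1) for _ in range(d_k)]) for _ in range(T_len)]
V = [[rng.gauss(0, 1) for _ in range(d_v)] for _ in range(T_len)]
beta = [rng.random() for _ in range(T_len)]
O = chunkwise_deltanet(Q, K, V, beta, C=4)
```

With $C = 1$ it reduces to the plain recurrence; larger $C$ moves work into chunk-level matmuls without changing the outputs beyond rounding.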

MLP states (TTT, Titans).

WY relies on each $A_t$ being identity minus rank-1: products of such factors stay identity minus low rank. For MLP-based memory, $\theta_t = \theta_{t-1} - \beta_t \nabla L(\theta_{t-1})$ has no $A_t$ at all, so there’s no rank-1 algebra to lean on.

TTT (Sun et al., 2024) and Titans (Behrouz et al., 2024) dodge this by giving up exactness inside the chunk: they evaluate all $C$ gradients at the start-of-chunk state $\theta^{(r)}$ and average them into one update. The chunk becomes a mini-batch SGD step with the inner state frozen during the chunk. What this costs is the fine-grained delta dynamics: two tokens in the same chunk that write to overlapping addresses get their writes added side by side instead of the second one rewriting the first. Chunk size $C$ becomes a mini-batch knob for the inner optimizer: bigger $C$ means more parallelism but coarser updates.
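To make the trade-off concrete, here is a toy sketch with a linear memory $W$ and the squared loss $\tfrac12\|Wk - v\|^2$ (an assumption for illustration; TTT and Titans use richer memories and losses, and the helper names are hypothetical). Per-token gradient steps with learning rate $\beta$ give $W \leftarrow W(I - \beta k k^\top) + \beta v k^\top$, i.e. the delta rule; the chunked version evaluates both gradients at the chunk-start state and averages them.

```python
def grad(W, k, v):
    """Gradient of 0.5 * ||W k - v||^2 with respect to W, i.e. (W k - v) k^T."""
    err = [sum(w * x for w, x in zip(row, k)) - vi for row, vi in zip(W, v)]
    return [[e * x for x in k] for e in err]


def step(W, G, lr):
    return [[w - lr * g for w, g in zip(rw, rg)] for rw, rg in zip(W, G)]


def sequential_chunk(W, K, V, lr):
    """Exact per-token GD: each token sees the state left by the previous one (delta rule)."""
    for k, v in zip(K, V):
        W = step(W, grad(W, k, v), lr)
    return W


def minibatch_chunk(W, K, V, lr):
    """TTT-style chunk: all gradients at the chunk-start state, averaged into one update."""
    grads = [grad(W, k, v) for k, v in zip(K, V)]
    mean = [[sum(g[i][j] for g in grads) / len(grads) for j in range(len(W[0]))]
            for i in range(len(W))]
    return step(W, mean, lr)


def read(W, k):
    return [sum(w * x for w, x in zip(row, k)) for row in W]


W0 = [[0.0, 0.0], [0.0, 0.0]]
K = [[1.0, 0.0], [1.0, 0.0]]          # two tokens writing to the same address
V = [[1.0, 0.0], [0.0, 1.0]]
print(read(sequential_chunk(W0, K, V, lr=1.0), K[0]))  # second write overwrites: [0.0, 1.0]
print(read(minibatch_chunk(W0, K, V, lr=1.0), K[0]))   # writes side by side: [0.5, 0.5]
```

With two tokens writing different values to the same key, the sequential chunk ends with the second value, while the mini-batch chunk stores their average.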

KDA (Kimi Delta Attention, from Kimi Linear; Moonshot AI, 2025) goes back to a structured-but-linear state (DPLR) precisely so it can keep an exact chunkwise kernel. The price is its own custom kernel: the rank-1 piece’s left and right vectors differ ($\alpha_t \odot k_t$ vs $k_t$), which breaks DeltaNet’s exact WY form, but the algorithmic shape stays close.

4. MQAR

The previous parts were about associative memory: a fixed-size state $S$ holds key-value pairs, and the model reads them back at query time. Softmax attention does this perfectly by never compressing the KV cache. Fixed-state recurrences trade exact recall for $O(d^2)$ per step. We want a benchmark that measures how well a model does associative recall.

Multi-Query Associative Recall (MQAR), introduced in Zoology (Arora et al., 2023), does exactly that. The model is shown a prefix of $N$ key-value pairs, then asked to recall the value for each key when it reappears later in the sequence. This requires storing many associations and retrieving them across stretches of distracting noise.

The sequence has two parts. First a contiguous block of key-value pairs $k_1 v_1 k_2 v_2 \ldots k_N v_N$, then a longer query region where each of the $N$ keys reappears once at a random position, padded with noise tokens. At each query position the model has to emit the associated value. A concrete example with $N=2$, vocab $=16$, $T=16$:

An MQAR example with N=2, vocab=16, T=16: a KV prefix at positions 0 to 3 holding the pairs (k₁,v₁)=(2,8) and (k₂,v₂)=(4,7), then a query region at positions 4 to 15 where both keys reappear once among noise tokens (x). The target at each query position is that key’s value; loss is taken only at the two query positions.

Some design details of MQAR:

  1. Disjoint vocab for keys and values. Keys are drawn from the first half of the dictionary, $[1, V/2)$, and values from the second half, $[V/2, V)$, with token 0 reserved as a noise placeholder. Role (key vs value) is readable off token identity, so the task isolates content matching from role inference.
  2. Queries scattered through noise, not interleaved with their answers. The naive layout $[k_1 v_1 \ldots k_N v_N\, q_1 a_1 q_2 a_2 \ldots]$ lets the model learn the shortcut “copy whatever came right after the last query.” Scattering queries among noise removes that shortcut.
  3. The target at a query position is the value, but the next input token is noise. At query position $P$, targets[P] = v and mask[P] = 1, but tokens[P+1] is random noise. The value $v$ is never fed back as input, only scored as a lookup target.
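A minimal generator following the three design details above, as a sketch rather than Zoology’s actual data pipeline (`make_mqar` is a hypothetical name). Assumptions: noise positions hold token 0 (the placeholder from detail 1), values may repeat across pairs, and `targets` is 0 wherever `mask` is 0. Query positions are spaced two apart so that `tokens[P + 1]` is always noise.

```python
import random


def make_mqar(n_kv, vocab, seq_len, rng):
    """One MQAR example as (tokens, targets, mask).

    Sketch assumptions: noise positions hold token 0, values may repeat,
    and targets are 0 wherever mask is 0.
    """
    start = 2 * n_kv                                   # query region begins after the KV prefix
    if seq_len - start < 2 * n_kv:
        raise ValueError("query region needs at least 2 * n_kv positions")
    keys = rng.sample(range(1, vocab // 2), n_kv)      # distinct keys from [1, V/2)
    values = [rng.randrange(vocab // 2, vocab) for _ in keys]   # values from [V/2, V)
    tokens, targets, mask = [0] * seq_len, [0] * seq_len, [0] * seq_len
    for i, (k, v) in enumerate(zip(keys, values)):
        tokens[2 * i], tokens[2 * i + 1] = k, v        # k_1 v_1 k_2 v_2 ... k_N v_N
    # Query slots P are spaced two apart, so tokens[P + 1] is always noise, never the answer.
    slots = rng.sample(range(start, seq_len - 1, 2), n_kv)
    for P, i in zip(slots, rng.sample(range(n_kv), n_kv)):
        tokens[P], targets[P], mask[P] = keys[i], values[i], 1
    return tokens, targets, mask


tokens, targets, mask = make_mqar(n_kv=4, vocab=256, seq_len=64, rng=random.Random(0))
```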

Here are the accuracies of naive baselines:

  • $1/V$: random pick from the full vocabulary (≈ 0.004 at $V=256$).
  • $1/N_{KV}$: random pick from the prefix’s value set (≈ 0.25 at $N_{KV}=4$, ≈ 0.03 at $N_{KV}=32$). This is a common failure mode, where the model collapses to one output value.
  • $\approx 1$: actual content-based retrieval.

Weak models that lack the capacity to solve MQAR plateau at $1/N_{KV}$, having learned to emit a value-shaped token but not the right one.

5. Implementation details

Some implementation details were necessary to get the architectures to solve MQAR.

Short causal conv before $Q$, $K$, $V$. Before each recurrent block we pass the per-position $q_t$, $k_t$, $v_t$ vectors through a depthwise causal 1D conv with kernel size 4. This turned out to be necessary for linear attention and DeltaNet to solve MQAR. Zoology explains this as a context split: the recurrence gives us global context, but we lack the local context to answer “is the previous token a query key that we need to read out?”. Looking at $K$’s conv, we see it usually learns a lookback.
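A sketch of the depthwise causal conv in plain Python (hypothetical `causal_depthwise_conv`). The indexing convention, where `weights[c][j]` multiplies the input $j$ steps back, is a choice made here; framework conv layers may store kernels in the opposite order. The example kernel `[0, 1, 0, 0]` is a hand-picked illustration of a pure one-step lookback, not a learned kernel.

```python
def causal_depthwise_conv(x, weights):
    """Depthwise causal 1D conv over a T x D sequence (list of rows).

    weights[c][j] multiplies x[t - j][c], so channel c only mixes its own values at
    positions t-W+1..t (W = kernel size); positions before 0 act as zero padding.
    """
    return [[sum(w * x[t - j][c] for j, w in enumerate(weights[c]) if t - j >= 0)
             for c in range(len(x[0]))] for t in range(len(x))]


x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
lookback = [[0.0, 1.0, 0.0, 0.0] for _ in range(2)]   # kernel size 4: copy the previous token
print(causal_depthwise_conv(x, lookback))  # [[0.0, 0.0], [1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
```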

Gating parameterization. Following Mamba2 (Dao & Gu, 2024) for the scalar (or diagonal) gate $\alpha_t$:

$$\alpha_t = \exp\!\big(-\mathrm{softplus}(\Delta) \cdot \sigma(W_\alpha x_t + b_\alpha)\big), \qquad \Delta \approx -10 \text{ at init.}$$

Bounded in $(0, 1]$ by construction (no NaN risk from sign flips inside a recurrent product), and starts at $\alpha_t \approx 1$, so the model defaults to “remember everything” and only learns to forget where useful.
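A small sketch of this gate for one position (hypothetical helper names; `softplus` and `sigmoid` are written out in numerically stable form, and the input and weights are arbitrary illustration values). With $\Delta = -10$ the gate sits just below 1 whatever the projection outputs; a larger $\Delta$ lets it forget.

```python
import math


def softplus(z):
    """log(1 + e^z), written to avoid overflow for large |z|."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def gate(x, w_alpha, b_alpha, delta):
    """Scalar gate alpha_t = exp(-softplus(delta) * sigmoid(w_alpha . x_t + b_alpha))."""
    z = sum(w * xi for w, xi in zip(w_alpha, x)) + b_alpha
    return math.exp(-softplus(delta) * sigmoid(z))


x = [0.3, -1.2, 2.0]
w_alpha, b_alpha = [0.5, 0.1, -0.4], 0.0
print(gate(x, w_alpha, b_alpha, delta=-10.0))  # at init: just below 1, "remember everything"
print(gate(x, w_alpha, b_alpha, delta=2.0))    # a larger delta lets the gate forget
```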

L2-normalize $k$ after the feature map. The rank-1 transition $A_t = I - \beta_t k_t k_t^\top$ has its only non-unit eigenvalue along $k_t$, equal to $1 - \beta_t \|k_t\|^2$. If $\|k_t\|$ varies freely across positions, the effective overwrite strength varies with it. L2-normalizing $k$ after the feature map pins that eigenvalue to $1 - \beta_t$, so the whole spectrum of $A_t$ lies in $[1-\beta_t, 1]$ and the learned $\beta_t$ is decoupled from the key norm.
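A quick numerical check of the eigenvalue argument (hypothetical helper names; $\beta_t = 0.5$ and $k = (3, 4)$ are chosen for illustration): applying $A_t$ to $k_t$ scales it by $1 - \beta_t\|k_t\|^2$, which is $-11.5$ for the raw key and $1 - \beta_t = 0.5$ after normalization, while a direction orthogonal to $k_t$ is left unchanged.

```python
import math


def apply_transition(k, beta, x):
    """A_t x = x - beta_t k_t (k_t . x), without forming the d_k x d_k matrix."""
    kx = sum(a * b for a, b in zip(k, x))
    return [xi - beta * ki * kx for xi, ki in zip(x, k)]


def l2_normalize(k):
    norm = math.sqrt(sum(a * a for a in k))
    return [a / norm for a in k]


beta = 0.5
k = [3.0, 4.0]                                  # ||k||^2 = 25
k_hat = l2_normalize(k)                         # approximately [0.6, 0.8]
print(apply_transition(k, beta, k))             # (1 - 0.5 * 25) * k = [-34.5, -46.0]
print(apply_transition(k_hat, beta, k_hat))     # about (1 - beta) * k_hat
perp = [-k_hat[1], k_hat[0]]                    # orthogonal to k_hat: eigenvalue 1
print(apply_transition(k_hat, beta, perp))      # about perp, unchanged
```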

To produce the results we used the parallel form for LA and the recurrent form for DeltaNet. For production kernels of everything in §2 and §3 (fused chunkwise, WY recurrence, conv, gating), see FLA (github.com/sustcsonglin/flash-linear-attention), which implements all the architectures in this post in Triton.

6. Results

We ran two minimal experiments. The guiding principle is to use the smallest setting in which the mechanism is visible enough to expose the architectural difference.

Capacity ceiling. Pure linear attention’s hard cap is $n \lesssim d_k$ (Part 1, §3). We pick $d_k = 16$ and use $N_{KV} \in \{4, 32\}$: one setting below the ceiling and one at $2\times$ the ceiling, with vocab $= 256$, $T = 128$, two layers, and four heads.

MQAR capacity at d_k=16, Linear Attention vs DeltaNet at N_KV ∈ {4, 32}. Below the ceiling both solve; at 2× the ceiling LA collapses toward the 1/N_KV baseline while DN degrades gracefully, holding at 0.77.

Below the ceiling, with $N_{KV}=4$, both solve cleanly. At $2\times$ the ceiling, LA collapses to $\approx 1/N_{KV}$, the “random pick from the value set” baseline. DN has the same theoretical capacity but holds at 0.77: its delta rule uses the same $d_k$ slots more cleanly, correcting the value already stored along a key direction instead of piling new writes on top of it, so its accuracy degrades more slowly.

Retention. Gating only buys something when there is something to decay. We hold $N_{KV} = 4$ fixed (well below capacity) and increase the sequence length, $T \in \{64, 512\}$, comparing DN and Gated DN at vocab $= 1024$.

Retention sweep, DeltaNet vs Gated DeltaNet at fixed N_KV=4 and T ∈ {64, 512}. Tied at T=64 (no noise to decay); Gated DN opens a small gap at T=512.

At $T=64$ both score the same: there is so little post-prefix noise that the gate never needs to activate. At $T=512$ the gap opens: Gated DN’s $\alpha_t < 1$ on the noise tokens shrinks the accumulating cross-talk, and the model retains its prefix writes better, by a few accuracy points.

7. Open questions

Some questions that came up during my research on this:

  1. How do you know what to memorize beforehand? Softmax has the luxury of never throwing away context, so it can always attend to the past if it contains useful information. Any finite-sized model has to make decisions. Even harder: every model here writes to memory based on the current $(k_t, v_t)$ pair. How does it know, at time $t$, what’s going to be useful later in the sequence? The outer-loop training has access to the downstream loss and presumably teaches the projections $W_k$, $W_v$ to pre-emphasize useful patterns, but a principled analysis would be interesting.

  2. Exp kernel under float32 should be effectively finite. Float32 has a 23-bit stored mantissa, so its relative precision is about $2^{-24} \approx 6 \times 10^{-8}$ (unit roundoff). At what truncation order $N$ of $\exp(q^\top k) \approx \sum_{n=0}^{N} (q^\top k)^n / n!$ does the residual fall below float32 precision for typical $|q^\top k|$? If small $N$ suffices, “softmax” and “polynomial linear attention” are empirically the same model in float32, which would soften the “infinite kernel” framing of Part 1.

  3. What about LSTM? The MIRAS axes were derived to organize the modern crop, but the LSTM cell with input/forget/output gates is also a fixed-state recurrence with content-dependent gating.

  4. Distillation of softmax attention. Take a softmax-trained model, train a fixed-state inner model per layer to mimic that layer’s attention, then swap it in. This would let us train with full-parallel softmax (cheap) and then serve with the recurrence (cheaper, linear at inference). The open question is how well the distillation works, and whether there is a principled way to decide which layers to leave as softmax.
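For open question 2, a small sketch (hypothetical `min_order`, evaluated in float64) finds the smallest truncation order $N$ whose relative residual $|\exp(x) - \sum_{n\le N} x^n/n!| / \exp(x)$ drops below float32’s unit roundoff $2^{-24}$. It only looks at the series itself and ignores rounding inside an actual float32 kernel.

```python
import math

F32_UNIT_ROUNDOFF = 2.0 ** -24   # float32 round-to-nearest relative error bound


def taylor_exp(x, order):
    """Partial sum sum_{n=0}^{order} x^n / n!, evaluated in float64."""
    term, total = 1.0, 1.0
    for n in range(1, order + 1):
        term *= x / n
        total += term
    return total


def min_order(x, tol=F32_UNIT_ROUNDOFF, max_order=200):
    """Smallest N whose relative truncation residual at x falls below tol."""
    target = math.exp(x)
    for order in range(max_order + 1):
        if abs(target - taylor_exp(x, order)) / target < tol:
            return order
    raise ValueError("tolerance not reached by max_order")


for x in (0.5, 1.0, 2.0, 4.0, 8.0):
    print(x, min_order(x))
```

The required order grows with $|q^\top k|$, so the answer depends on how large the logits get in practice.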

References

  1. Katharopoulos, Vyas, Pappas, Fleuret (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020. arXiv:2006.16236
  2. Yang, Wang, Zhang, Shen, Kim (2024). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. NeurIPS 2024. arXiv:2406.06484
  3. Dao, Gu (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. ICML 2024. arXiv:2405.21060
  4. Yang, Kautz, Hatamizadeh (2024). Gated Delta Networks: Improving Mamba2 with Delta Rule. arXiv:2412.06464
  5. Sun, Li, Geng, Hua, Wang, Zhao, Liu, Hardt, Chen, Pan, Lin, Wang, Han, Guestrin (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv:2407.04620
  6. Behrouz, Zhong, Mirrokni (2024). Titans: Learning to Memorize at Test Time. arXiv:2501.00663
  7. Moonshot AI (2025). Kimi Linear: An Expressive, Efficient Attention Architecture. arXiv:2510.26692
  8. Arora, Eyuboglu, Timalsina, Johnson, Poli, Zou, Rudra, Ré (2023). Zoology: Measuring and Improving Recall in Efficient Language Models. arXiv:2312.04927
  9. Bischof, Van Loan (1987). The WY Representation for Products of Householder Matrices. SIAM J. Sci. Stat. Comput. 8.