Softmax attention is a local constant estimator

From weighted averages to local linear corrections and Parallax

TL;DR

Softmax attention can be read as a kernel regression estimator: given a query $q$, it predicts the value at $q$ by taking a similarity-weighted average of all previous values. A statistician would call this a Nadaraya-Watson estimator. It is a local constant estimator: around each query we fit a single constant value, which assumes that nearby queries $\tilde q$ have roughly the same value as $q$. In this context “local” means not sequence locality, but locality in key/query representation space.

This lets us construct a failure mode. A weighted average works well when the nearby values are roughly flat around the query (picture 1D queries and keys). But if the values have a local slope, and especially if the query is near the boundary of the key cloud, the weighted average lags behind the correct value. Local Linear Attention (Zuo et al., 2025) fixes this by fitting an affine function instead of a constant. Interestingly, instead of solving the Local Linear Attention system exactly, Parallax (Zuo et al., 2026) adds a learned query-like projection $R = W_R x$ on top of ordinary softmax attention to estimate the slope.

Figure (embedded post from @tilderesearch on X): softmax attention lagging behind the actual trend.

1. Starting from softmax

For one query token, we have previous key-value pairs

$$(k_1, v_1), \ldots, (k_t, v_t)$$

and a query vector $q$. Softmax attention returns

$$o(q) = \sum_{j=1}^{t} p_j(q)\, v_j$$

with

$$p_j(q) = \frac{\exp(q^\top k_j / h)}{\sum_{\ell=1}^{t} \exp(q^\top k_\ell / h)}.$$

Here $h$ is the bandwidth, or in transformer language the attention temperature (the usual score scale is $1/h$). In vanilla attention we usually write the scale as $1/\sqrt{d_{\text{head}}}$, so this corresponds to $h = \sqrt{d_{\text{head}}}$ (Vaswani et al., 2017).

Small $h$ makes attention sharp. Large $h$ makes attention diffuse:

$$\begin{array}{rcl} h \text{ small} &\Rightarrow& \text{nearest/similar keys dominate} \\ h \text{ large} &\Rightarrow& \text{many keys get nontrivial weight} \\ h \to \infty &\Rightarrow& \text{uniform average} \end{array}$$

This answers the first small but important confusion: softmax is global over the causal context, but the weights are local in representation space. Every previous token is eligible, but keys with larger $q^\top k_j$ get exponentially larger weights. This is the mirror image of how an L2 loss treats outliers in linear regression: there, far-away points pull hardest on the fit; here, dissimilar keys are exponentially suppressed.
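To make the bandwidth behaviour concrete, here is a minimal sketch in plain Python (standard library only). The helper names `softmax_weights` and `softmax_attention` and the toy keys and values are illustrative assumptions, not taken from any library or paper.

```python
import math

def softmax_weights(q, keys, h):
    """Return p_j(q) = exp(q.k_j / h) / sum_l exp(q.k_l / h)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / h for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def softmax_attention(q, keys, values, h):
    """Weighted average of the value vectors under the softmax weights."""
    p = softmax_weights(q, keys, h)
    dim = len(values[0])
    return [sum(pj * v[d] for pj, v in zip(p, values)) for d in range(dim)]

# Toy data (made-up numbers): three 2D keys, scalar values stored as 1D vectors.
q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.5, 0.5], [-1.0, 0.0]]
values = [[1.0], [2.0], [3.0]]

sharp = softmax_weights(q, keys, h=0.05)     # small h: the best-matching key dominates
diffuse = softmax_weights(q, keys, h=100.0)  # large h: weights approach uniform
out = softmax_attention(q, keys, values, h=1.0)
```

With these numbers, `sharp` puts almost all of its mass on the first key, while `diffuse` is close to one third per key.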

For two keys $k_a$ and $k_b$,

$$\frac{p_a}{p_b} = \exp((q^\top k_a - q^\top k_b)/h).$$

If $q^\top k_a$ is larger by $2h$, then $k_a$ gets $e^2 \approx 7.4\times$ the weight of $k_b$. That is the “nearby keys get larger weights” statement. Nearby means “nearby under the attention score”, not nearby in token position.
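A quick arithmetic check of that ratio, as a sketch. The function name `weight_ratio` and the specific scores are hypothetical; only the formula comes from the text.

```python
import math

def weight_ratio(score_a, score_b, h):
    """p_a / p_b = exp((score_a - score_b) / h); the softmax normalizer cancels."""
    return math.exp((score_a - score_b) / h)

h = 8.0  # e.g. sqrt(d_head) for d_head = 64
ratio = weight_ratio(score_a=3.0 + 2 * h, score_b=3.0, h=h)
print(round(ratio, 2))  # 7.39
```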

2. Softmax as weighted constant fitting

Now ignore the transformer vocabulary for a moment and pretend we are doing regression. The keys are input points, the values are labels, and $q$ is the test point where we want a prediction:

$$k_j \mapsto v_j, \qquad q \mapsto ?$$

Suppose that around this query $q$, the function from keys to values is approximately constant:

$$f(k) \approx c.$$

For this one query, fit $c$ by minimizing the weighted squared error

$$L(c) = \sum_{j=1}^{t} w_j(q)\, \|c - v_j\|^2$$

where

$$w_j(q) = \exp(q^\top k_j / h).$$

The derivative is

$$\nabla_c L = 2 \sum_j w_j(c - v_j).$$

Set it to zero:

$$c \sum_j w_j = \sum_j w_j v_j$$

so

$$c = \frac{\sum_j w_j v_j}{\sum_j w_j}.$$

Define normalized weights

$$p_j = \frac{w_j}{\sum_\ell w_\ell}$$

and we get

$$\boxed{ c = \sum_j p_j v_j }$$

which is exactly softmax attention.
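The derivation can be checked numerically. The sketch below uses scalar keys and values with made-up numbers (all names are illustrative): it computes the softmax output, confirms the gradient of the weighted loss vanishes there, and confirms that nudging $c$ in either direction makes the loss worse.

```python
import math

def weighted_loss(c, weights, values):
    """L(c) = sum_j w_j * (c - v_j)^2 for scalar values."""
    return sum(w * (c - v) ** 2 for w, v in zip(weights, values))

# Hypothetical 1D setup: scalar keys, scalar values, one query.
q, h = 1.0, 0.5
keys = [0.1, 0.4, 0.7, 0.9]
values = [2.0, -1.0, 0.5, 3.0]

weights = [math.exp(q * k / h) for k in keys]  # unnormalized w_j(q)
total = sum(weights)
c_star = sum(w * v for w, v in zip(weights, values)) / total  # softmax output

# The gradient 2 * sum_j w_j (c - v_j) vanishes at c_star ...
grad = 2 * sum(w * (c_star - v) for w, v in zip(weights, values))
# ... and moving c away from c_star in either direction increases the loss.
base = weighted_loss(c_star, weights, values)
worse = [weighted_loss(c_star + eps, weights, values) for eps in (-0.1, 0.1)]
```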

This is what people mean when they say softmax attention is the Nadaraya-Watson estimator (Wang, Shi, Fox, 2025). It is kernel regression where the kernel is

$$K(q, k_j) = \exp(q^\top k_j / h).$$

The estimator is local because the weights depend on similarity to $q$. It is constant because the local model we fit is just one value $c$.

3. Where the lag comes from

The weighted average is not wrong by default. If the values are truly flat around qq, the local constant estimator is the right estimator. The problem appears when the values have a local trend.

Let’s use a 1D key for intuition. Assume values come from a smooth function:

$$v_j = f(k_j).$$

Softmax/Nadaraya-Watson predicts

$$\hat f(q) = \sum_j p_j f(k_j).$$

Now Taylor expand $f(k_j)$ around $q$:

$$f(k_j) = f(q) + f'(q)(k_j-q) + \frac{1}{2}f''(q)(k_j-q)^2 + \cdots.$$

Plug this into the weighted average:

$$\hat f(q) = \sum_j p_j \left[ f(q) + f'(q)(k_j-q) + \frac{1}{2}f''(q)(k_j-q)^2 + \cdots \right].$$

Because $\sum_j p_j = 1$,

$$\hat f(q)-f(q) = f'(q)\sum_j p_j(k_j-q) + \frac{1}{2}f''(q)\sum_j p_j(k_j-q)^2 + \cdots.$$

Let

$$\bar k = \sum_j p_j k_j.$$

Since $\sum_j p_j (k_j - q) = \bar k - q$, the leading bias is

$$\boxed{ \hat f(q)-f(q) \approx f'(q)(\bar k-q) }$$

This is the lag.

If the weighted neighborhood is symmetric around $q$, then $\bar k \approx q$ and this first-order bias disappears. But if $q$ sits near the boundary of the key cloud, most of the mass can lie on one side. If $q$ is to the right of the weighted key center, then

$$\bar k - q < 0.$$

If the function is increasing, then

$$f'(q) > 0.$$

So

$$\hat f(q)-f(q) < 0.$$

The local average underpredicts. It lags behind the trend.
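Here is a small sketch of the lag on a toy problem. The function $f(k) = 2k + 1$, the key grid, and the bandwidth are all made-up assumptions. Because this $f$ is exactly linear, the higher-order Taylor terms vanish and the bias should match $f'(q)(\bar k - q)$ up to floating-point error. Note that with a 1D dot-product score and $q > 0$, larger keys simply get larger weights; that is a quirk of the toy setup.

```python
import math

def softmax_weights_1d(q, keys, h):
    """Softmax weights for scalar keys with score q * k / h."""
    scores = [q * k / h for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def f(k):
    return 2.0 * k + 1.0  # increasing, slope f'(k) = 2 everywhere

keys = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
values = [f(k) for k in keys]
q, h = 1.0, 0.25                    # the query sits on the right edge of the keys

p = softmax_weights_1d(q, keys, h)
k_bar = sum(pj * k for pj, k in zip(p, keys))
f_hat = sum(pj * v for pj, v in zip(p, values))  # softmax / Nadaraya-Watson prediction

bias = f_hat - f(q)
predicted_bias = 2.0 * (k_bar - q)  # f'(q) * (k_bar - q)
```

The weighted key center `k_bar` lands to the left of `q`, so `bias` comes out negative: the estimate trails the trend.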

4. Local linear attention

The natural next move is simple: instead of fitting a local constant, fit a local affine function.

Softmax attention fits

$$f(k) \approx b.$$

Local Linear Attention fits

$$f(k) \approx b + a(k-q)$$

in 1D, or

$$f(k) \approx b + W(k-q)$$

in higher dimensions (Zuo et al., 2025).

For the 1D case, minimize

$$\sum_j p_j \bigl(b + a(k_j-q) - v_j\bigr)^2.$$

At the query itself, $k=q$, so the fitted value is

$$f(q) = b + a(q-q) = b.$$

So the linear term does not appear directly in the final evaluation point. But it changes the fitted intercept $b$, because $a$ and $b$ are fitted together. The slope absorbs the local trend, so the intercept no longer has to equal the raw weighted average.

The normal equations give the useful form

$$b = \bar v - a(\bar k-q)$$

with

$$a \approx \frac{\mathrm{Cov}_p(k,v)}{\mathrm{Var}_p(k)+\lambda}.$$
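For completeness, a sketch of where these two expressions come from, assuming a ridge penalty $\lambda a^2$ is added to the weighted loss above (that penalty is an assumption here; it is what produces the $\lambda$ in the denominator). Setting the derivatives with respect to $b$ and $a$ to zero gives

$$\begin{aligned}
\sum_j p_j\bigl(b + a(k_j-q) - v_j\bigr) &= 0 \;\Rightarrow\; b = \bar v - a(\bar k - q), \\
\sum_j p_j (k_j-q)\bigl(b + a(k_j-q) - v_j\bigr) + \lambda a &= 0.
\end{aligned}$$

Substituting $b$ into the second equation turns the bracket into $a(k_j-\bar k) - (v_j-\bar v)$, whose $p$-weighted mean is zero, so $(k_j-q)$ can be replaced by $(k_j-\bar k)$:

$$a\,\mathrm{Var}_p(k) - \mathrm{Cov}_p(k,v) + \lambda a = 0 \;\Rightarrow\; a = \frac{\mathrm{Cov}_p(k,v)}{\mathrm{Var}_p(k)+\lambda}.$$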

So the local linear prediction is

$$\boxed{ o^{LLA} = \bar v - \frac{\mathrm{Cov}_p(k,v)}{\mathrm{Var}_p(k)+\lambda} (\bar k-q) }$$

where $\mathrm{Cov}_p$ and $\mathrm{Var}_p$ are taken under the softmax weights $p_j$, $\lambda$ is the regularization strength, and

$$\bar v = \sum_j p_j v_j$$

is the usual softmax output.

Look at the correction term. If the query is to the right of the weighted key center, $\bar k-q<0$, and values increase with keys, $\mathrm{Cov}_p(k,v)>0$, then the correction is positive. It raises the softmax estimate and corrects the lag.
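A minimal 1D sketch of that correction, under the same toy assumptions as before: a made-up linear trend $f(k) = 2k + 1$, a query on the right edge of the keys, and a hypothetical helper name `local_linear_1d`. With `lam=0` and exactly linear values, the fitted slope is exact and the corrected output should recover $f(q) = 3$.

```python
import math

def local_linear_1d(q, keys, values, h, lam=0.0):
    """Return (softmax_output, local_linear_output) for scalar keys and values."""
    scores = [q * k / h for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    p = [e / z for e in exps]

    k_bar = sum(pj * k for pj, k in zip(p, keys))
    v_bar = sum(pj * v for pj, v in zip(p, values))  # ordinary softmax output
    cov_kv = sum(pj * (k - k_bar) * (v - v_bar) for pj, k, v in zip(p, keys, values))
    var_k = sum(pj * (k - k_bar) ** 2 for pj, k in zip(p, keys))

    a = cov_kv / (var_k + lam)   # local slope
    b = v_bar - a * (k_bar - q)  # intercept, i.e. the prediction at k = q
    return v_bar, b

keys = [i / 10 for i in range(11)]
values = [2.0 * k + 1.0 for k in keys]  # toy trend f(k) = 2k + 1, so f(1) = 3
softmax_out, lla_out = local_linear_1d(q=1.0, keys=keys, values=values, h=0.25)
_, lla_ridge = local_linear_1d(q=1.0, keys=keys, values=values, h=0.25, lam=0.1)
```

Raising `lam` shrinks the slope, so `lla_ridge` lands between the softmax output and the unregularized fit.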

This is unrelated to “linear attention” in the finite-state sense from my earlier posts. Annoying naming collision. Local Linear Attention is not saying “sub-quadratic recurrent attention”. It is saying “local linear regression around each query”.

5. What exact LLA costs

In higher dimensions, the local linear correction has the shape

$$o_i^{LLA} = o_i^{SA} - (1+\eta_i)\Sigma_{KV}^{(i)}\rho_i^\star.$$

Here

$$o_i^{SA} = \bar v_i$$

is the ordinary softmax output,

$$\Sigma_{KV}^{(i)} = \mathbb E_{p_i}\left[(v_j-\bar v_i)(k_j-\bar k_i)^\top\right]$$

is the local key-value covariance under the softmax weights, and $\rho_i^\star$ is the direction that says how to probe that covariance.

Exact Local Linear Attention gets this probe by solving a linear system:

$$\rho_i^\star = \Sigma_i^{-1}\mu_i.$$

That is the beautiful version mathematically and the annoying version computationally. The Parallax paper lists the practical problems:

  1. A per-query solve is expensive and I/O-heavy.
  2. The regularization $\lambda$ is a real tradeoff. Too large and the correction collapses back toward softmax; too small and the solve can become unstable.
  3. Iterative solvers are not what modern low-precision attention kernels want to spend their life doing.

So we have the correction we want, but the exact computation is awkward.

6. Parallax: learn the probe

Parallax (Zuo et al., 2026) keeps the softmax part and replaces the exact solve with a learned projection.

Normal attention has

$$Q = W_Q x, \qquad K = W_K x, \qquad V = W_V x.$$

Parallax adds

$$R = W_R x.$$

For a query position $i$, write

$$\rho_i = W_R x_i.$$

Then Parallax uses

$$\boxed{ o_i^{PLX} = o_i^{SA} - \Sigma_{KV}^{(i)}\rho_i }$$

So the model no longer solves for $\rho_i^\star$. It learns a query-like probe $\rho_i$ during training.

Softmax gives the local weighted average. Parallax subtracts a learned covariance correction.
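As a sketch, the boxed formula can be evaluated directly for one query. Everything here is illustrative: `parallax_output` is a hypothetical helper, the probe `rho` is passed in as a plain vector rather than produced by a learned $W_R$, and the numbers are made up. This is a reference-style computation under those assumptions, not the paper's kernel.

```python
import math

def parallax_output(q, rho, keys, values, h):
    """Return (o_sa, o_plx) with o_plx = o_sa - Sigma_KV @ rho for one query."""
    scores = [sum(a * b for a, b in zip(q, k)) / h for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    p = [e / z for e in exps]

    dk, dv = len(keys[0]), len(values[0])
    k_bar = [sum(pj * k[d] for pj, k in zip(p, keys)) for d in range(dk)]
    o_sa = [sum(pj * v[d] for pj, v in zip(p, values)) for d in range(dv)]

    # Sigma_KV @ rho = E_p[(v_j - v_bar) * ((k_j - k_bar) . rho)]
    correction = [0.0] * dv
    for pj, k, v in zip(p, keys, values):
        probe = sum((k[d] - k_bar[d]) * rho[d] for d in range(dk))
        for d in range(dv):
            correction[d] += pj * (v[d] - o_sa[d]) * probe

    o_plx = [o - c for o, c in zip(o_sa, correction)]
    return o_sa, o_plx

# Toy data (made-up numbers): four 2D keys and 2D values.
q = [1.0, 0.5]
rho = [0.3, -0.2]  # stands in for W_R x_i
keys = [[0.2, 0.1], [0.9, 0.4], [0.5, -0.3], [1.0, 1.0]]
values = [[1.0, 0.0], [2.0, 1.0], [0.5, -1.0], [3.0, 2.0]]
o_sa, o_plx = parallax_output(q, rho, keys, values, h=1.0)
```

A zero probe leaves the softmax output untouched, and so does a context in which all values are identical, since the covariance is then zero.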

The paper also drops the $(1+\eta_i)$ boundary amplification factor from exact LLA. That factor is well behaved when $\rho_i^\star$ comes from the exact constrained solve. Once $\rho_i$ is just $W_R x_i$, the same normalization can blow up or flip sign.

Parallax is not replacing softmax the way linear attention or DeltaNet replace softmax. It still uses softmax weights $p_{ij}$ and still streams the full KV cache. Instead, it changes the estimator sitting on top of the same softmax neighborhood:

$$\begin{array}{rcl} \text{Softmax} &:& \text{local constant estimator} \\ \text{Exact LLA} &:& \text{local linear estimator with solved probe} \\ \text{Parallax} &:& \text{local linear correction with learned probe} \end{array}$$

This matters because softmax can only de-emphasize a token by assigning it small positive probability. Parallax can subtract value directions. In the paper’s analysis, this changes attention patterns: the base softmax becomes more diffuse, the correction branch handles sharper discrimination, and the model relies less on the usual attention sink behavior.

7. Why more FLOPs can be faster

The other initially weird claim in the Parallax paper is that the model does more FLOPs but can match or beat FlashAttention decoding.

This is possible because decoding attention is often memory-bandwidth-bound. The slow part is not always the arithmetic; it is streaming the KV cache from HBM.

FlashAttention loads K/V tiles and computes

$$S_1 = QK^\top, \qquad P_1 = \mathrm{softmax}(S_1), \qquad O_1 = P_1 V.$$

Parallax adds a second branch

$$S_2 = RK^\top, \qquad P_2 = P_1 \odot S_2, \qquad O_2 = P_2 V.$$

But it reuses the same K/V tiles that were already loaded for the softmax branch, so the kernel performs more FLOPs per byte moved from memory. That is, it has higher arithmetic intensity,

$$\text{arithmetic intensity} = \frac{\text{FLOPs}}{\text{bytes moved from HBM}}.$$

If the GPU was waiting on memory anyway, doing more FLOPs on already-loaded data can be close to free, or even faster if it improves utilization.
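To see why both branches can share one pass over the cache, expand the covariance: $\Sigma_{KV}\rho = \sum_j p_j s_{2,j} v_j - \bigl(\sum_j p_j s_{2,j}\bigr)\bar v$ with $s_{2,j} = \rho^\top k_j$, which is $O_2$ minus a scalar multiple of $O_1$. The sketch below is plain Python, not a GPU kernel: `parallax_single_pass` is a hypothetical name, the online-softmax rescaling is the standard trick rather than anything confirmed from the paper's implementation, and the toy numbers are made up. Each $(k_j, v_j)$ pair is touched once and feeds both accumulators.

```python
import math

def parallax_single_pass(q, rho, keys, values, h):
    """One pass over (k_j, v_j); both branches reuse each loaded pair."""
    dv = len(values[0])
    m = -math.inf      # running max of the softmax scores
    z = 0.0            # running sum of exp(s1 - m)
    acc1 = [0.0] * dv  # running sum of exp(s1 - m) * v_j         -> O_1
    acc2 = [0.0] * dv  # running sum of exp(s1 - m) * s2_j * v_j  -> O_2
    acc_s2 = 0.0       # running sum of exp(s1 - m) * s2_j
    for k, v in zip(keys, values):
        s1 = sum(a * b for a, b in zip(q, k)) / h  # softmax branch score
        s2 = sum(a * b for a, b in zip(rho, k))    # probe branch score
        m_new = max(m, s1)
        scale = math.exp(m - m_new)                # rescale the old accumulators
        w = math.exp(s1 - m_new)
        z = z * scale + w
        acc_s2 = acc_s2 * scale + w * s2
        acc1 = [x * scale + w * vd for x, vd in zip(acc1, v)]
        acc2 = [x * scale + w * s2 * vd for x, vd in zip(acc2, v)]
        m = m_new
    o1 = [x / z for x in acc1]
    o2 = [x / z for x in acc2]
    mean_s2 = acc_s2 / z
    # Sigma_KV @ rho = O_2 - (sum_j p_j s2_j) * O_1
    return [a - (b - mean_s2 * a) for a, b in zip(o1, o2)]

# Toy data (made-up numbers).
q = [1.0, 0.5]
rho = [0.3, -0.2]
keys = [[0.2, 0.1], [0.9, 0.4], [0.5, -0.3], [1.0, 1.0]]
values = [[1.0, 0.0], [2.0, 1.0], [0.5, -1.0], [3.0, 2.0]]
out = parallax_single_pass(q, rho, keys, values, h=1.0)
```

The extra work per loaded pair is one more dot product and one more weighted accumulation, with no additional reads of the keys or values.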

References

  1. Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, Polosukhin (2017). Attention Is All You Need. NeurIPS 2017. arXiv:1706.03762
  2. Wang, Shi, Fox (2025). Test-time regression: a unifying framework for designing sequence models with associative memory. arXiv:2501.12352
  3. Zuo, Yin, Zeng, Li, Zhu, Wang (2025). Local Linear Attention: An Optimal Interpolation of Linear and Softmax Attention For Test-Time Regression. arXiv:2510.01450
  4. Zuo, Pai, Zeng, Dewulf, Hu, Wang (2026). Parallax: Parameterized Local Linear Attention for Language Modeling. arXiv:2605.29157