Since 2017, the Transformer’s attention mechanism has remained largely unchanged. Most attempts to improve efficiency have aimed to completely swap out softmax attention. A new study takes an alternative approach: it preserves softmax attention and attaches an auxiliary correction module.
Researchers from Northwestern University, Tilde Research, and the University of Washington present a parameterized Local Linear Attention architecture named ‘Parallax.’ It is designed to scale to large language model pretraining and is co-designed with the Muon optimizer.
Parallax does not pursue efficiency by reducing computation. Instead, it intentionally adds extra computation, then optimizes that computation to run more cheaply on contemporary GPUs.
What is Parallax
Parallax is built upon Local Linear Attention (LLA). LLA originates from a test-time regression framework, which interprets attention as a regression solver operating over key-value pairs.
Under this interpretation, keys serve as training data points, values act as labels, and the query functions as the test point. Softmax attention then corresponds to the Nadaraya-Watson estimator, a nonparametric method that fits a local constant function for each query.
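The correspondence can be checked numerically. The sketch below is a minimal NumPy illustration, not code from the paper: it assumes a single query, the usual 1/sqrt(d) score scaling, and the exponential kernel exp(q·k / sqrt(d)). The function names are chosen here for illustration.

```python
import numpy as np

def softmax_attention(q, K, V):
    """Standard single-query softmax attention."""
    scores = K @ q / np.sqrt(q.shape[0])
    scores = scores - scores.max()          # stabilise the exponentials
    p = np.exp(scores)
    p /= p.sum()
    return p @ V

def nadaraya_watson(q, K, V):
    """Kernel regression: kernel-weighted average of the labels V at the
    test point q, using the kernel exp(q.k / sqrt(d))."""
    kernel = np.exp(K @ q / np.sqrt(q.shape[0]))
    return (kernel[:, None] * V).sum(axis=0) / kernel.sum()

rng = np.random.default_rng(0)
K = rng.normal(size=(16, 8))   # keys   = training inputs
V = rng.normal(size=(16, 4))   # values = training labels
q = rng.normal(size=8)         # query  = test point

attn_out = softmax_attention(q, K, V)
nw_out = nadaraya_watson(q, K, V)
```

Both functions return the same vector up to floating-point error, because the softmax normaliser is exactly the kernel sum in the Nadaraya-Watson denominator.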
LLA enhances this local constant estimate by upgrading it to a local linear estimate. The research team demonstrates that this refinement produces a strictly lower integrated mean squared error, offering improved bias-variance tradeoffs for associative memory tasks.
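To see why a local linear fit can reduce bias, consider labels that are an exactly linear function of the keys. The following sketch is a generic weighted least-squares illustration under assumed softmax kernel weights, not the paper's LLA implementation. The local constant estimate returns a weighted average of the labels, while the local linear estimate fits an intercept plus slope around the query and reads off the intercept.

```python
import numpy as np

def kernel_weights(q, K):
    s = K @ q / np.sqrt(q.shape[0])
    w = np.exp(s - s.max())
    return w / w.sum()

def local_constant(q, K, V):
    """Nadaraya-Watson: weighted mean of the labels."""
    return kernel_weights(q, K) @ V

def local_linear(q, K, V):
    """Weighted least squares for v ~ a + B (k - q); the prediction at q is a."""
    w = kernel_weights(q, K)
    X = np.hstack([np.ones((K.shape[0], 1)), K - q])   # intercept + centred keys
    XtW = X.T * w                                      # apply kernel weights
    theta = np.linalg.solve(XtW @ X, XtW @ V)          # one linear solve per query
    return theta[0]

rng = np.random.default_rng(0)
A, b = rng.normal(size=(3, 4)), rng.normal(size=3)
K = rng.normal(size=(64, 4))
V = K @ A.T + b              # labels are an exactly linear function of the keys
q = rng.normal(size=4)
truth = A @ q + b

err_const = np.linalg.norm(local_constant(q, K, V) - truth)
err_linear = np.linalg.norm(local_linear(q, K, V) - truth)
```

On this toy problem `err_linear` is at the level of floating-point noise, while `err_const` is not, since the weighted mean of the keys generally differs from the query. Note that `local_linear` needs a linear solve for every query.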
However, LLA faces scalability challenges. Its exact forward pass requires solving a linear system for each query using a parallel conjugate gradient (CG) solver. This solver introduces three problems: heavy I/O demands, a difficult balance between regularization and expressiveness, and incompatibility with low-precision arithmetic.
Parallax eliminates the solver entirely. Instead, it trains an additional projection matrix WR and computes a probe for each token as ρi = WR xi, where xi is the layer’s input at position i. WR is a learnable matrix that directly probes the KV covariance from the layer’s input.
Thus, Parallax retains the local linear principle but substitutes the per-query solve with a learned, query-like projector. This makes the approach simpler, more efficient, and easier to implement.
How the Mechanism Works
Parallax restructures LLA as softmax attention combined with an additive correction. The output consists of the softmax attention result minus a projected covariance term, represented in the paper as the KV covariance multiplied by the learned probe ρi.
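The description above can be sketched for a single query as follows. This is one reading of the article's wording rather than the paper's reference code: it assumes the KV covariance is the softmax-weighted covariance between values and keys, and that scores use the usual 1/sqrt(d) scaling. The names `parallax_output` and `cov_vk` are illustrative.

```python
import numpy as np

def softmax_probs(q, K):
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return p / p.sum()

def parallax_output(q, K, V, rho):
    """Softmax attention output minus (KV covariance @ probe)."""
    p = softmax_probs(q, K)
    v_bar = p @ V                  # plain softmax attention output
    k_bar = p @ K                  # softmax-weighted mean key
    # Assumed definition: softmax-weighted covariance of values and keys,
    # shape (d_v, d_k).
    cov_vk = (V - v_bar).T @ (p[:, None] * (K - k_bar))
    return v_bar - cov_vk @ rho

rng = np.random.default_rng(0)
K = rng.normal(size=(16, 8))
V = rng.normal(size=(16, 4))
q = rng.normal(size=8)
rho = 0.1 * rng.normal(size=8)     # stands in for the learned probe

out = parallax_output(q, K, V, rho)
```

The correction vanishes when `rho` is the zero vector, which leaves only the `v_bar` term.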
The team also removes a component of LLA known as the boundary amplification factor, setting it to zero. This adjustment is essential for stability. Once the probe becomes parametric, the original geometric interpretation no longer holds. Retaining the factor could cause the scaling to diverge or reverse sign.
Parallax belongs to a broader family of attention mechanisms. The researchers categorize them along three dimensions: bandwidth, probe construction, and affine structure. At one end of the spectrum, Parallax collapses exactly into softmax attention when the probe norm approaches zero.
Setting WR = 0 causes a Parallax layer to behave identically to softmax attention. This means a pretrained Transformer checkpoint can be adapted by introducing WR and fine-tuning.
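A small causal single-head layer makes the checkpoint argument concrete. The sketch below is an assumption-laden NumPy illustration, using the same covariance reading as the article's description and hypothetical weight names `Wq`, `Wk`, `Wv`, `Wr`. It checks that a zero `Wr` reproduces a plain softmax layer.

```python
import numpy as np

def _probs(scores):
    p = np.exp(scores - scores.max())
    return p / p.sum()

def softmax_layer(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq.T, X @ Wk.T, X @ Wv.T
    out = np.zeros_like(V)
    for i in range(X.shape[0]):                    # causal: attend to 0..i
        p = _probs(K[: i + 1] @ Q[i] / np.sqrt(Q.shape[1]))
        out[i] = p @ V[: i + 1]
    return out

def parallax_layer(X, Wq, Wk, Wv, Wr):
    Q, K, V, R = X @ Wq.T, X @ Wk.T, X @ Wv.T, X @ Wr.T   # R[i] is the probe
    out = np.zeros_like(V)
    for i in range(X.shape[0]):
        Ki, Vi = K[: i + 1], V[: i + 1]
        p = _probs(Ki @ Q[i] / np.sqrt(Q.shape[1]))
        v_bar, k_bar = p @ Vi, p @ Ki
        cov_vk = (Vi - v_bar).T @ (p[:, None] * (Ki - k_bar))
        out[i] = v_bar - cov_vk @ R[i]
    return out

rng = np.random.default_rng(0)
d_model, d_head, n = 16, 8, 10
X = rng.normal(size=(n, d_model))
Wq, Wk, Wv = (rng.normal(size=(d_head, d_model)) / np.sqrt(d_model) for _ in range(3))
Wr_zero = np.zeros((d_head, d_model))              # freshly added, zero-initialised

same = np.allclose(parallax_layer(X, Wq, Wk, Wv, Wr_zero),
                   softmax_layer(X, Wq, Wk, Wv))
print(same)  # True
```

Zero initialisation is the natural starting point for fine-tuning under this reading: the converted model begins with the same outputs as the original checkpoint.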
The Hardware Argument
Parallax adopts the streaming architecture of FlashAttention, adding a covariance branch that leverages the same key-value data stream.
The team decomposes the forward pass into two parallel scoring branches. Both branches share the online maximum, the rescaling factor, and the K and V tiles. Consequently, Parallax requires no additional I/O per iteration.
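One way to picture the shared stream is an online-softmax loop that carries a few extra accumulators. The sketch below is a NumPy illustration of that idea, not the paper's kernel. It assumes the correction is the softmax-weighted value-key covariance applied to the probe, which can be rewritten using only running sums of exp(score), exp(score)·(k·ρ), exp(score)·v, and exp(score)·(k·ρ)·v. Each K and V tile is read once and feeds both scoring branches.

```python
import numpy as np

def parallax_dense(q, rho, K, V):
    """Reference: materialise all probabilities at once."""
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max()); p /= p.sum()
    v_bar, k_bar = p @ V, p @ K
    cov_vk = (V - v_bar).T @ (p[:, None] * (K - k_bar))
    return v_bar - cov_vk @ rho

def parallax_streaming(q, rho, K, V, tile=4):
    d = q.shape[0]
    m = -np.inf                    # shared running max
    l = 0.0                        # sum of exp(score)
    r = 0.0                        # sum of exp(score) * (k . rho)
    a = np.zeros(V.shape[1])       # sum of exp(score) * v
    b = np.zeros(V.shape[1])       # sum of exp(score) * (k . rho) * v
    for start in range(0, K.shape[0], tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]  # loaded once
        s = Kt @ q / np.sqrt(d)    # QK branch
        c = Kt @ rho               # RK branch, same K tile
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)  # shared rescaling factor
        e = np.exp(s - m_new)
        l = l * scale + e.sum()
        r = r * scale + e @ c
        a = a * scale + e @ Vt
        b = b * scale + (e * c) @ Vt
        m = m_new
    v_bar = a / l
    correction = b / l - v_bar * (r / l)   # equals cov_vk @ rho
    return v_bar - correction

rng = np.random.default_rng(0)
K = rng.normal(size=(16, 8))
V = rng.normal(size=(16, 4))
q = rng.normal(size=8)
rho = 0.1 * rng.normal(size=8)

dense = parallax_dense(q, rho, K, V)
streamed = parallax_streaming(q, rho, K, V, tile=4)
```

The two results agree up to floating-point error, and the streaming version never holds more than one tile of keys and values at a time.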
The critical advantage is increased arithmetic intensity (AI), the ratio of floating-point operations to high-bandwidth memory traffic. In scenarios where KV work dominates, Parallax approximately doubles the arithmetic intensity by adding computation while reusing the same memory stream.
This transitions attention into a more compute-bound regime, which is precisely where kernel optimization delivers gains on modern hardware.
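A back-of-the-envelope count shows where the factor of roughly two can come from. The function below is a deliberately crude model written for this article, not a figure from the paper. It assumes single-query decode, counts only the dot product against each key and the weighted accumulation of each value, treats the covariance branch as costing the same as the softmax branch, and ignores exponentials, query loads, and output writes. The `score_branches` parameter is hypothetical.

```python
def decode_arithmetic_intensity(head_dim, bytes_per_elem=2, score_branches=1):
    """FLOPs per byte of KV traffic for one cached token (rough model)."""
    # Per branch: one dot product with k (2*d FLOPs) and one weighted
    # accumulation of v (2*d FLOPs).
    flops = score_branches * (2 * head_dim + 2 * head_dim)
    # One key row and one value row are read regardless of branch count.
    bytes_moved = 2 * head_dim * bytes_per_elem
    return flops / bytes_moved

softmax_ai = decode_arithmetic_intensity(128)                      # 1.0
parallax_ai = decode_arithmetic_intensity(128, score_branches=2)   # 2.0
```

Under these assumptions the ratio is independent of head dimension: the numerator doubles while the denominator stays fixed, which is the pattern the text describes.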
The team developed a decode kernel prototype in CuTeDSL for NVIDIA Hopper GPUs. Hopper’s tensor core matmul instructions process tiles of at least 64 rows, while a decode step provides only one query row. Therefore, the QK and RK products can be calculated together within instructions that standard attention already uses.
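The packing idea can be mimicked in NumPy, with the caveat that this is only an analogy for the tensor core tile and not CuTeDSL code. The sketch assumes a 64-row left operand in which the decode query occupies one row and the probe occupies another row that would otherwise be padding.

```python
import numpy as np

TILE_ROWS = 64   # minimum row count quoted in the text for Hopper matmul tiles

def stacked_scores(q, rho, K):
    """Compute q.K^T and rho.K^T with a single padded 64-row matmul."""
    d = q.shape[0]
    tile = np.zeros((TILE_ROWS, d), dtype=K.dtype)
    tile[0] = q                  # row 0: the single decode query
    tile[1] = rho                # row 1: the probe, using an otherwise idle row
    scores = tile @ K.T          # one (64, d) x (d, n) product
    return scores[0], scores[1]

rng = np.random.default_rng(0)
K = rng.normal(size=(32, 8))
q = rng.normal(size=8)
rho = rng.normal(size=8)

qk, rk = stacked_scores(q, rho, K)
```

The remaining 62 rows are still zero padding in this toy version, which illustrates why the second scoring branch adds little cost to a decode step that was already underusing the tile.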
They benchmarked against FlashAttention 2 and 3 on H200 GPUs using BF16 precision, testing batch sizes from 1 to 2,048 and context lengths from 128 to 32,768. The prototype kernel matches or exceeds FlashAttention performance across all tested configurations, with reported speedups of 1.54× in the compute-matched scenario and 1.14× in the I/O-matched scenario.

What the Experiments Reveal
The team tested Parallax on synthetic benchmarks and during LLM pretraining at 0.6B and 1.7B parameter scales. All models relied on the Qwen-3 architecture through the torchtitan framework. Training utilized the Ultra-FineWeb dataset with sequences of length 4096. Comparison methods spanned softmax attention (Transformer), Mamba, Gated DeltaNet, MesaNet, and Kimi Delta Attention.
Parallax achieved the top overall accuracy on the MAD-Benchmark, reaching 0.716 on average. It delivered steady improvements on tasks requiring recall, such as In-Context-Recall and Selective-Copying. On compression and memorization tasks, it remained competitive with other approaches.
For language modeling, Parallax paired with Muon delivered the lowest perplexity at both model sizes. It also posted the highest average accuracy across downstream evaluations. At the 1.7B scale, Parallax reached 62.45 compared to the Transformer’s 61.43.
Two ablation studies help identify the source of these improvements. A Transformer matched for parameter count closed only a small portion of the performance gap. Meanwhile, a compute-matched Parallax continued to outperform both baselines. The paper concludes these results highlight the mechanism itself rather than additional parameters or computation.
The Optimizer Connection
A central insight from this work centers on how the optimizer interacts with the architecture. Parallax shows a substantial edge when trained with Muon. Under AdamW, that advantage narrows considerably or vanishes altogether.
Muon is a newer optimizer designed for matrix parameters in hidden layers. It derives updates from the polar factor of the momentum buffer, ensuring each update has a condition number of exactly one. Earlier research demonstrated this leads to weight matrices with better conditioning.
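The condition-number claim follows directly from the definition of the polar factor. The sketch below computes it exactly with an SVD for a deliberately ill-conditioned stand-in momentum matrix. This is for illustration only: practical Muon implementations typically approximate the polar factor with a Newton-Schulz iteration rather than an SVD, so their singular values are only close to one.

```python
import numpy as np

def polar_factor(M):
    """Orthogonal factor U V^T of the SVD M = U S V^T."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
# Columns scaled very differently, so the raw buffer is badly conditioned.
momentum = rng.normal(size=(6, 4)) * np.array([10.0, 1.0, 0.1, 0.01])

update = polar_factor(momentum)
sv_momentum = np.linalg.svd(momentum, compute_uv=False)
sv_update = np.linalg.svd(update, compute_uv=False)

cond_momentum = sv_momentum[0] / sv_momentum[-1]   # large
cond_update = sv_update[0] / sv_update[-1]         # 1 up to rounding
```

Every singular value of `update` equals one up to rounding, so all directions in the momentum buffer receive the same step magnitude.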
The team traced this performance gap to the correction branch. They introduced a correction-to-output ratio (COR) metric. In the deepest layers, COR surpasses 8 under Muon but remains below 4 under AdamW.
The WR projection is hit particularly hard. Its stable rank deteriorates under AdamW yet remains robust under Muon. A gating experiment backed up this finding. With AdamW, the model essentially learns to turn off the correction branch instead of leveraging it.
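Stable rank has a standard definition: the squared Frobenius norm divided by the squared spectral norm. The sketch below uses that definition, on the assumption that the paper's measurement follows it.

```python
import numpy as np

def stable_rank(W):
    """||W||_F^2 / ||W||_2^2, computed from the singular values."""
    sv = np.linalg.svd(W, compute_uv=False)   # sorted in descending order
    return float((sv ** 2).sum() / sv[0] ** 2)

full = stable_rank(np.eye(4))                                  # 4.0
collapsed = stable_rank(np.outer(np.ones(4), np.arange(1.0, 5.0)))  # 1.0
```

A matrix whose energy is spread evenly across directions scores close to its dimension, while one dominated by a single direction scores close to one. A falling stable rank for WR therefore means the probe is using fewer effective directions.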
The team describes this as the first clear empirical evidence of meaningful architecture-optimizer co-design for attention layers. They do not argue Muon with WSD is the ideal training recipe. An ablation in the appendix reveals the advantage weakens during the decay phase.
How the Score Distributions Compare
Parallax also generates score patterns distinct from traditional softmax attention. Its per-token weights can dip below zero and climb above one in absolute value, which standard softmax weights cannot achieve.
The team observed three key differences. Parallax can actively subtract contributions from unimportant tokens. It notably reduces the attention sink effect on the first token. Its base softmax entropy stays higher, producing more spread-out attention weights.
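A two-token toy case shows how weights outside the [0, 1] range can arise. The sketch assumes the correction is the softmax-weighted value-key covariance applied to the probe. Under that assumption the output can be rewritten as a weighted sum of values with effective weights p_j · (1 − (k_j − k̄)·ρ), which still sum to one. The numbers are chosen by hand for illustration.

```python
import numpy as np

def effective_weights(q, K, rho):
    """Return softmax weights p and the effective per-token weights."""
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    p /= p.sum()
    c = (K - p @ K) @ rho          # centred key projected on the probe
    return p, p * (1.0 - c)

K = np.array([[1.0, 0.0],
              [-1.0, 0.0]])
q = np.zeros(2)                    # both tokens get softmax weight 0.5
rho = np.array([2.0, 0.0])

p, w = effective_weights(q, K, rho)
# p is [0.5, 0.5]; w is [-0.5, 1.5]
```

Here the first token is actively subtracted and the second receives a weight above one, yet the weights still sum to one, so the output is an affine rather than convex combination of the values.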
Strengths, Weaknesses, and Unresolved Questions
Strengths
- Retains softmax attention as its foundation, so existing pretrained Transformers can transition simply by adding WR and fine-tuning.
- Introduces no additional I/O overhead per step by reusing the existing FlashAttention key-value stream.
- Doubles arithmetic intensity, with a prototype kernel that matches or surpasses FlashAttention 2/3 during decode.
- Delivers consistent improvements in perplexity and downstream performance across both parameter-matched and compute-matched comparisons.
Weaknesses and Unresolved Questions
- Performance gains rely heavily on Muon; with AdamW, the advantage largely vanishes.
- The exact reason behind this optimizer dependence is not yet fully understood.
- Experiments are limited to 1.7B scale, without testing MoE, longer contexts, or larger-scale runs.
- The advantage fades during the WSD decay phase, with only partial recovery from weight decay annealing.
Key Takeaways
- Parallax preserves softmax attention while introducing a learned covariance correction branch, serving as an alternative to LLA’s per-query conjugate gradient solver.
- It doubles arithmetic intensity while using the same KV stream, achieving a decode kernel on par with or faster than FlashAttention 2/3.
- It shows consistent gains in perplexity and downstream performance at 0.6B and 1.7B scales, confirmed through parameter-matched and compute-matched ablations.
- Performance improvements depend strongly on Muon; switching to AdamW causes the advantage to diminish or disappear.
- Setting WR to zero recovers standard softmax attention exactly, meaning pretrained Transformers can convert by adding WR and fine-tuning.