How transformers, RNNs and SSMs are more alike than you think

Uncovering surprising links between seemingly unrelated LLM architectures may pave the way for effective idea exchange and efficiency gains.

Even in the age of Mamba and other exciting linear RNNs and state space models (SSMs), transformer architectures still reign supreme as the backbone of LLMs. However, this might soon change: hybrid architectures such as Jamba, Samba and Griffin are quite promising. They are significantly more time- and memory-efficient than transformers, and their capabilities are not greatly diminished compared to attention-based LLMs.

At the same time, recent research has exposed deep connections between different architectural options: transformers, RNNs, SSMs and matrix mixers. This is very exciting because it allows for the transfer of ideas from one architecture to another. In this long read, we’ll mainly follow Transformers are RNNs and Mamba 2, getting elbow-deep in algebra to understand how:

  • Transformers may sometimes be RNNs (section 2),
  • State space models may hide inside the mask in the self-attention mechanism (section 4),
  • Mamba may sometimes be rewritten as masked self-attention (section 5).

I find these connections quite fascinating, and I hope you, too, will appreciate them.

If you’re curious to learn more about the architectural development of transformer alternatives, please feel free to check my separate post about them.

Masked self-attention in LLMs: a reminder

A classic LLM self-attention layer looks as follows:

Or, in more detail:

Let’s briefly recall what happens here:

  • We multiply the matrix $Q$ of queries by $K^T$, the transposed matrix of keys. The product is an $L\times L$ matrix of scalar products of queries and keys.
  • We normalize the resulting matrix (dividing by $\sqrt{d}$).
  • Then, we multiply it elementwise by the $L\times L$ attention mask. The picture shows the default causal mask — the 0-1 matrix on the left. It zeroes out the products (earlier query, later key), preventing the attention from “looking into the future”.
  • Then, we take the softmax.
  • Finally, we multiply the attention weights (let’s denote this matrix by $A$) by the values $V$; the $t$-th row of the output is:

output_t = \sum_i a_{ti}v_i.

Thus, the $i$-th value is taken with the “attention weight of the $t$-th query for the $i$-th key”.
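To make the recap concrete, here is a minimal NumPy sketch of a single masked self-attention head (no batching or multi-head logic; all names are illustrative). Note that in code the causal mask is usually applied by setting disallowed scores to $-\infty$ before the softmax, which is what makes their weights exactly zero.

    import numpy as np

    def masked_softmax_attention(Q, K, V):
        """One attention head: Q, K, V are (L, d) matrices of queries, keys, values."""
        L, d = Q.shape
        scores = Q @ K.T / np.sqrt(d)                    # (L, L) scalar products, normalized
        mask = np.tril(np.ones((L, L), dtype=bool))      # causal mask: 1 on and below the diagonal
        scores = np.where(mask, scores, -np.inf)         # "don't look into the future"
        A = np.exp(scores - scores.max(axis=-1, keepdims=True))
        A = A / A.sum(axis=-1, keepdims=True)            # softmax over keys
        return A @ V                                     # output_t = sum_i a_{ti} v_i

    rng = np.random.default_rng(0)
    Q, K, V = rng.normal(size=(3, 6, 4))                 # L = 6 tokens, d = 4 dimensions
    y = masked_softmax_attention(Q, K, V)                # result has shape (6, 4)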

Many design choices in this architecture may be altered. Let’s explore some possible changes.

Linearized attention

Softmax in the attention formula ensures that the values are mixed with positive coefficients that sum to 1. This is good because it keeps some of the statistics intact, but it is very restricting: no term can escape the softmax’s brackets, even if we would love to exploit associativity, as in $(QK^T)V = Q(K^TV)$. Why do we need associativity, you ask? Because the order of multiplication can drastically influence the complexity:

  • In the left formula, we need to calculate an $L\times L$ matrix, which gives an $O(L^2d)$ compute cost and an $O(L^2)$ memory cost if we allow this matrix to materialize fully in memory.
  • In the right formula, we need to calculate a $d\times d$ matrix, which gives an $O(Ld^2)$ compute cost and an $O(d^2)$ memory cost if we allow this matrix to materialize fully in memory.

Thus, as context length grows, the first formula becomes more costly than the second one and ultimately becomes prohibitively expensive.
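To see the difference in a concrete (if toy) setting, here is a small NumPy sketch comparing the two orders of multiplication; the shapes are illustrative, and with $L \gg d$ the right-hand order never materializes an $L\times L$ matrix:

    import numpy as np

    rng = np.random.default_rng(0)
    L, d = 2048, 64                          # long sequence, small head dimension
    Q, K, V = rng.normal(size=(3, L, d))

    left  = (Q @ K.T) @ V                    # goes through an (L, L) intermediate: O(L^2 d) time
    right = Q @ (K.T @ V)                    # goes through a (d, d) intermediate: O(L d^2) time

    print(np.allclose(left, right))          # True: same result, very different costs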

So, you know what? Let’s crack this softmax open! Let’s write down the formula with the softmax in a bit more detail:

y_t = \sum_{s=1}^L \frac{\exp \left(m_{ts}\cdot q_t k_s^T/ \sqrt{d}\right)}{\sum_{j=1}^L \exp \left(m_{tj}\cdot q_t k_j^T/\sqrt{d}\right)} v_s,

where the fraction \frac{\exp \left(m_{ts}\cdot q_t k_s^T/ \sqrt{d}\right)}{\sum_{j=1}^L \exp \left(m_{tj}\cdot q_t k_j^T/\sqrt{d}\right)} is the softmax.

The most inconvenient thing here is the exponential: we just can’t take anything out of it. Let’s just erase it.

y_t = \sum_{s=1}^L\frac{m_{ts}\cdot q_tk_s^T}{\sum_{j=1}^Lm_{tj}\cdot q_tk_j^T}v_s.

Note that the normalizer $\frac1{\sqrt{d}}$ has vanished: without the exponential, it is a common factor of the numerator and the denominator and simply cancels out.

The problem with this formula is that $q_tk_s^T$ is not guaranteed to be positive, so we may end up mixing the values with coefficients of different signs (which is counterintuitive); even worse, the denominator may vanish, breaking everything. We may add a “good” elementwise function $\phi$ (a kernel) to mitigate this:

y_t = \sum_{s=1}^L\frac{m_{ts}\cdot \phi(q_t)\phi(k_s^T)}{\sum_{j=1}^Lm_{tj}\cdot\phi(q_t)\phi(k_j^T)}v_s.

The original work suggested $\phi(x) = 1 + \mathrm{elu}(x)$.

This variation of the attention mechanism is called linearized attention. And the good thing is that we can now play with associativity:

y_t = \frac{\phi(q_t)\sum_{s=1}^Lm_{ts}\cdot \phi(k_s^T)v_s}{\phi(q_t)\sum_{j=1}^Lm_{tj}\cdot\phi(k_j^T)}.

A word of caution. The relations between $M$, $K^T$ and $V$ inside the brackets are now quite complex and are not expressed through just ordinary and elementwise matrix products. We’ll discuss this computational block in more detail in the next section.

Now, if $M$ is the causal mask, with 1 on the diagonal and below it and 0 above the diagonal,

then everything becomes even simpler:

y_t = \frac{\phi(q_t)\sum_{s=1}^t\phi(k_s^T)v_s}{\phi(q_t)\sum_{j=1}^t\phi(k_j^T)}.

And this can be calculated in a very simple, recurrent way:
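We keep a matrix state $h_t$ and a vector state $z_t$ and update them as we move along the sequence:

h_t = h_{t-1} + \phi(k_t)^T v_t, \qquad z_t = z_{t-1} + \phi(k_t)^T, \qquad y_t = \frac{\phi(q_t)\,h_t}{\phi(q_t)\,z_t}, \qquad h_0 = 0,\ z_0 = 0.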

I think you won’t be surprised to learn that the ICML 2020 paper where linearized attention was first suggested is called Transformers are RNNs. Note that we have two hidden states here: a vector $z_t$ and a matrix $h_t$ ($\phi(k_t)^Tv_t$ is a column multiplied by a row, so it’s a $d\times d$ matrix). In recent papers, linearized attention is often presented in an even simpler, distilled way, with no $\phi$ and no denominator:

y_t = q_t\sum_{s=1}^tm_{ts}k_s^Tv_s.
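Here is a minimal NumPy sketch of this distilled form, assuming the 0-1 causal mask, computed both ways: as masked attention and as a recurrence over a $d\times d$ hidden state (all names are illustrative).

    import numpy as np

    rng = np.random.default_rng(0)
    L, d = 8, 4
    Q, K, V = rng.normal(size=(3, L, d))

    # Parallel ("transformer") form: y = (M * (Q K^T)) V with the causal 0-1 mask M
    M = np.tril(np.ones((L, L)))
    y_parallel = (M * (Q @ K.T)) @ V

    # Recurrent ("RNN") form with a d x d hidden state
    h = np.zeros((d, d))
    y_recurrent = np.empty((L, d))
    for t in range(L):
        h += np.outer(K[t], V[t])            # h_t = h_{t-1} + k_t^T v_t
        y_recurrent[t] = Q[t] @ h            # y_t = q_t h_t

    print(np.allclose(y_parallel, y_recurrent))   # True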

Linearized attention has two important benefits:

  • As a recurrent mechanism, it has linear complexity at inference time with respect to the sequence length $L$,
  • As a transformer model, it can be efficiently trained in a parallel way.

At this point, you probably have a pressing question: if linearized attention is so cool, why don’t we see it in every LLM, and why do we still complain about attention’s quadratic complexity? Alas, linearized-attention-based LLMs are less stable during training and reach somewhat lower capability than standard self-attention. It looks as if less information can be squeezed through the fixed $d\times d$ bottleneck than through the adaptable $L\times L$ one. You can check experimental results in this article by Manifest AI, for example.

Further reading. The idea that RNNs and linearized attention are related is rediscovered and revisited in many recent works. Moreover, it’s becoming common in such papers to have a matrix hidden state that is updated as follows:

h_t = h_{t-1} + k_t^Tv_t,

where $k_t$ and $v_t$ are a “key” and a “value” of sorts, with the output of an RNN layer being something like
q_th_t = q_t(h_{t-1} + k_t^T v_t) = \ldots =
= q_t(k_1^Tv_1 + \ldots + k_t^T v_t) =
= (q_tk_1^T)v_1 + \ldots + (q_tk_t^T)v_t,

which boils down to linear attention.

As examples, I would suggest checking these two papers:

  • xLSTM — a May 2024 paper suggesting a development of the renowned LSTM recurrent architecture. Its mLSTM block features a matrix hidden state which is updated in the already-familiar way:
    c_t = \mathrm{forget\_gate}_t \cdot c_{t-1} + \mathrm{input\_gate}_t \cdot v_t k_t^T

and is further multiplied by a “query” to obtain the output. (The paper’s linear algebra setup is transposed relative to ours: queries, keys and values are columns, not rows; hence the reversed order in $v_tk_t^T$.)

  • Learning to (Learn at Test Time), which appeared in July 2024. It is another RNN architecture with a matrix hidden state and a very peculiar semantics: the hidden state is the parameter $W$ of a function that is optimized by gradient descent as we iterate over $t$:

Here, everything is also transposed, so the order is odd again. The math is more complicated than $W_t = W_{t-1} + v_tk_t^T$, but may be simplified to it.

Exploring attention masks

Now that we have simplified the masked attention mechanism, we can start developing it again. And there is obvious room for experimentation: choosing a different lower triangular (“never look into the future”) mask $M$ instead of the ascetic 0-1 causal mask. But first, we need to deal with an emerging inefficiency.

In the previous section, we were actually very lucky to work with a simple 0-1 causal mask $M$. In the general case, the recurrent trick no longer works:
y_1=\frac{\phi(q_1)\left({\color{red}{m_{11}}}\cdot \phi(k_1^T)v_1\right)}{\phi(q_1)\left({\color{red}{m_{11}}}\cdot\phi(k_1^T)\right)}

y_2=\frac{\phi(q_2)\left({\color{red}{m_{21}}}\cdot \phi(k_1^T)v_1 + {\color{red}{m_{22}}}\cdot \phi(k_2^T)v_2\right)}{\phi(q_2)\left({\color{red}{m_{21}}}\cdot \phi(k_1^T) + {\color{red}{m_{22}}}\cdot \phi(k_2^T)\right)}

y_3=\frac{\phi(q_3)\left({\color{red}{m_{31}}}\cdot \phi(k_1^T)v_1 + {\color{red}{m_{32}}}\cdot \phi(k_2^T)v_2 + {\color{red}{m_{33}}}\cdot \phi(k_3^T)v_3\right)}{\phi(q_3)\left({\color{red}{m_{31}}}\cdot \phi(k_1^T) + {\color{red}{m_{32}}}\cdot \phi(k_2^T) + {\color{red}{m_{33}}}\cdot \phi(k_3^T)\right)}
\vdots

The coefficients ${\color{red}{m_{ts}}}$ are different now, and there is no recurrent formula relating $y_3$ to $y_2$. Thus, for each $t$ we would have to compute the sum from scratch, making the asymptotic complexity quadratic in $L$ again, instead of linear.

The answer is that we won’t use just any mask $M$, only special, “good” ones. Namely, the ones that can be quickly multiplied (not elementwise!) by other matrices. To see how we can benefit from this property, let’s understand how to calculate

y_t = q_t\sum_{s=1}^Lm_{ts}k_s^T v_s

efficiently. First, let’s clarify what’s happening here:

If we descend to the individual index level:

For the discussion that follows, it will be useful to color individual indices instead of blocks:

Now, we are ready to formulate the four-step algorithm:

Step 1. Take $K$ and $V$ and create a three-dimensional tensor $Z$ with
z_{{\color{magenta}{t}}{\color{orange}{i}}{\color{teal}{j}}} = k_{{\color{magenta}{t}}\color{orange}{i}}v_{{\color{magenta}{t}}{\color{teal}{j}}}

(Each axis is marked with its length.) This step takes $O(Ld^2)$ time and memory.

Note that if we sum this tensor over the magenta axis ${\color{magenta}{t}}$, we’ll get the product $K^TV$:
\sum_{{\color{magenta}{t}}}k_{{\color{magenta}{t}}\color{orange}{i}}v_{{\color{magenta}{t}}{\color{teal}{j}}} = \sum_{{\color{magenta}{t}}}(K^T)_{{\color{orange}{i}}{\color{magenta}{t}}}v_{{\color{magenta}{t}}{\color{teal}{j}}}=(K^TV)_{{\color{orange}{i}}{\color{teal}{j}}}
Step 2. Now, we multiply $M$ by this tensor (not elementwise!) in such a way that $M$ is multiplied by each “column” of $Z$ along the magenta axis ${\color{magenta}{t}}$.

This gives us exactly

\sum_{{\color{magenta}{t}}} m_{{\color{cyan}{s}}{\color{magenta}{t}}} k_{{\color{magenta}{t}}\color{orange}{i}} v_{{\color{magenta}{t}}{\color{teal}{j}}} = \sum_{{\color{magenta}{t}}} m_{{\color{cyan}{s}}{\color{magenta}{t}}} (K^T)_{{\color{orange}{i}}{\color{magenta}{t}}} v_{{\color{magenta}{t}}{\color{teal}{j}}}

Let’s denote the result by $H$. Now all that’s left is to multiply everything by $q_t$, which will be done in two more steps.
Step 3a. Take $Q$ and multiply it elementwise by each ${\color{teal}{j}} = \mathrm{const}$ layer of $H$:

This results in:
q_{{\color{cyan}{s}}{\color{orange}{i}}}\sum_{{\color{magenta}{t}}} m_{{\color{cyan}{s}}{\color{magenta}{t}}} k_{{\color{magenta}{t}}\color{orange}{i}} v_{{\color{magenta}{t}}{\color{teal}{j}}}

This step requires $O(Ld^2)$ time and memory.
Step 3b. Sum the resulting tensor along the ${\color{orange}{i}}$ axis:

This step also requires $O(Ld^2)$ time and memory. And we obtain exactly the thing we need:

\sum_{\color{orange}{i}}q_{{\color{cyan}{s}}{\color{orange}{i}}}\sum_{{\color{magenta}{t}}} m_{{\color{cyan}{s}}{\color{magenta}{t}}} k_{{\color{magenta}{t}}\color{orange}{i}} v_{{\color{magenta}{t}}{\color{teal}{j}}}
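Putting the four steps together as a deliberately naive NumPy sketch (the einsum subscripts mirror the indices $s$, $t$, $i$, $j$ above; this version still materializes every tensor, so it is for clarity rather than efficiency):

    import numpy as np

    rng = np.random.default_rng(0)
    L, d = 8, 4
    Q, K, V = rng.normal(size=(3, L, d))
    M = np.tril(rng.uniform(size=(L, L)))        # some lower triangular mask

    Z = np.einsum('ti,tj->tij', K, V)            # Step 1:  z_{tij} = k_{ti} v_{tj},  shape (L, d, d)
    H = np.einsum('st,tij->sij', M, Z)           # Step 2:  multiply M along the t axis
    Y = np.einsum('si,sij->sj', Q, H)            # Steps 3a + 3b: scale by q_{si}, then sum over i

    # Sanity check against the "masked attention" form y = (M * (Q K^T)) V
    print(np.allclose(Y, (M * (Q @ K.T)) @ V))   # True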

The most crucial step is the second one, and we deliberately omitted its complexity analysis. A naïve estimate is:

  • $O(L^2)$ for each matrix-vector multiplication,

  • repeating it $d^2$ times,

So it’s a whopping $O(L^2d^2)$. But we will take not just any $M$. We will take $M$ such that the complexity of multiplying $M$ by a vector is $O(RL)$ for some constant (and not very large) $R$.

For example, if $M$ is the 0-1 causal matrix, multiplication by it boils down to computing cumulative sums, which can be done in $O(L)$ time. But there are many more options for structured matrices with fast matrix-vector multiplication.
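For the causal mask specifically, step 2 is just a cumulative sum along the sequence axis, so the whole computation fits in $O(Ld^2)$ time without ever building an $L\times L$ matrix. A minimal sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    L, d = 8, 4
    Q, K, V = rng.normal(size=(3, L, d))

    Z = np.einsum('ti,tj->tij', K, V)            # outer products k_t^T v_t, shape (L, d, d)
    H = np.cumsum(Z, axis=0)                     # step 2 with the causal mask = running sum over t
    Y = np.einsum('ti,tij->tj', Q, H)            # contract with the queries

    M = np.tril(np.ones((L, L)))                 # check against the masked-attention form
    print(np.allclose(Y, (M * (Q @ K.T)) @ V))   # True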

In the next section, we’ll discuss one such matrix type — semiseparable matrices, which will turn out to be tightly connected with state space models.

Semiseparable matrices and state space models

Let’s recall that (discretized) state space models (SSMs) are a class of sequential models connecting a 1-dimensional input $x_t$, an $r$-dimensional hidden state $h_t$, and a 1-dimensional output $y_t$ in the following way:
h_t= A_t h_{t-1} + B_t x_t,
y_t = C_t h_t + D_t x_t

So, in the discretized form, it’s just a fancy linear RNN with a skip connection. Moreover, for the rest of the story, we’ll even forget about skip connections by setting $D_t = 0$.

Now, let’s write our SSM as a single matrix multiplication
y = \mathcal{M}x,
where
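\mathcal{M}_{ts} = C_t A_t A_{t-1} \cdots A_{s+1} B_s \quad \text{for } t \ge s, \qquad \mathcal{M}_{ts} = 0 \quad \text{for } t < s

(the empty product at $s = t$ is the identity, so $\mathcal{M}_{tt} = C_tB_t$; this follows from simply unrolling the recurrence).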

So, $\mathcal{M}$ is lower triangular, just like the attention masks we discussed earlier.

This type of matrix comes with a great boon:

An $L\times L$ lower triangular matrix whose elements can be presented in this way is called semiseparable. It can be stored using $O(rL)$ memory, and its matrix-vector multiplication costs $O(rL)$ instead of the default $O(L^2)$ (see this paper for details).

This means that every state space model gives rise to a structured attention mask $\mathcal{M}$ which can be used in an efficient transformer model with linearized attention.
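To make the correspondence tangible, here is a minimal NumPy sketch (with $D_t = 0$ and scalar inputs, as in the text) that runs the recurrence and also materializes $\mathcal{M}$ explicitly, only to check that both give the same output; the whole point of semiseparability, of course, is that you never need the dense $L\times L$ matrix.

    import numpy as np

    rng = np.random.default_rng(0)
    L, r = 6, 3
    A = rng.normal(size=(L, r, r))               # A_t
    B = rng.normal(size=(L, r, 1))               # B_t
    C = rng.normal(size=(L, 1, r))               # C_t
    x = rng.normal(size=L)                       # scalar inputs x_t

    # Recurrent form: h_t = A_t h_{t-1} + B_t x_t,  y_t = C_t h_t
    h = np.zeros((r, 1))
    y_scan = np.empty(L)
    for t in range(L):
        h = A[t] @ h + B[t] * x[t]
        y_scan[t] = (C[t] @ h).item()

    # Matrix-mixer form: y = M x with M_{ts} = C_t A_t ... A_{s+1} B_s for t >= s
    M = np.zeros((L, L))
    for t in range(L):
        prod = np.eye(r)                         # empty product for s = t
        for s in range(t, -1, -1):
            M[t, s] = (C[t] @ prod @ B[s]).item()
            prod = prod @ A[s]                   # now prod = A_t ... A_s, ready for the next s
    y_matrix = M @ x

    print(np.allclose(y_scan, y_matrix))         # True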

The curious thing, though, is that a semiseparable matrix $\mathcal{M}$ is already quite complex and expressive, even without all the query-key-value machinery around it. Moreover, it may itself be a masked attention mechanism. We’ll see this in the next section.

State space duality

Here, we finally arrive at one of the central results of the Mamba 2 paper.

Let’s again consider $y = \mathcal{M}u$, where $u = u(x)$ is a function of the inputs and $\mathcal{M}$ is a semiseparable matrix. Moreover, let’s consider a very special case where each $A_t$ is a scalar matrix: $A_t = a_tI$. Then the formulas become especially simple:

y_t = a_t\ldots a_2C_tB_1u_1 + a_t\ldots a_3C_tB_2u_2 + \ldots + a_tC_tB_{t-1}u_{t-1} + C_tB_tu_t.

Note that

\underbrace{C_i}_{1\times r}\underbrace{B_j}_{r\times 1}

is just a number. Moreover, we can stack the $C_i$ and $B_i$ into matrices $C$ and $B$ such that $(CB)_{ts} = C_tB_s$:

Now, we’ll also need the matrix
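$G$ with entries

G_{ts} = a_t a_{t-1}\cdots a_{s+1} \quad \text{for } t \ge s \ (\text{so } G_{tt} = 1), \qquad G_{ts} = 0 \quad \text{for } t < s.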

Then it’s easy to check that
\mathcal{M}u = (G\otimes CB)u,

where $\otimes$ denotes the elementwise product, just as the mask is applied in attention.

Have we seen this somewhere? Indeed! It’s masked attention with:

  • $G$ for the mask,
  • $C$ for $Q$, the query matrix,
  • $B$ for $K^T$, the transposed key matrix,
  • $u$ for $V$, the value matrix.

In classic SSMs, $B$ and $C$ are constant (input-independent), but Mamba made them data-dependent, reinforcing the alignment: in attention, queries and keys are also computed from the input. This correspondence between specific state space models and masked attention was introduced in the Mamba 2 paper as state space duality.
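Here is a small NumPy sketch of the duality for the scalar case $A_t = a_tI$ (with $B_t$ and $C_t$ stored as rows, so the $CB$ of the text corresponds to C @ B.T in the code): the recurrent scan and the masked-attention product give the same output.

    import numpy as np

    rng = np.random.default_rng(0)
    L, r = 6, 3
    a = rng.uniform(0.5, 1.0, size=L)                # scalar A_t = a_t I
    B = rng.normal(size=(L, r))                      # row t plays the role of B_t (a "key")
    C = rng.normal(size=(L, r))                      # row t plays the role of C_t (a "query")
    u = rng.normal(size=L)                           # inputs u_t (the "values")

    # SSM scan: h_t = a_t h_{t-1} + B_t u_t,  y_t = C_t h_t
    h = np.zeros(r)
    y_scan = np.empty(L)
    for t in range(L):
        h = a[t] * h + B[t] * u[t]
        y_scan[t] = C[t] @ h

    # Masked-attention form: y = (G * (C B^T)) u with G_{ts} = a_t ... a_{s+1} for t >= s
    logcum = np.cumsum(np.log(a))
    G = np.tril(np.exp(logcum[:, None] - logcum[None, :]))
    y_attn = (G * (C @ B.T)) @ u

    print(np.allclose(y_scan, y_attn))               # True: state space duality in action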

Further reading. The idea of using matrix mixers instead of more sophisticated architectures is not new; one of the notable early examples is MLP-Mixer, which employed MLPs instead of convolution or attention for spatial mixing in CV tasks.

Although current research is mostly focused on LLMs, there are also several papers proposing non-transformer, matrix-mixer architectures for encoder models. Examples include:

  • FNet from Google Research, with a matrix mixer $\mathcal{M}$ based on the Fourier transform,
  • Hydra, which, among other ideas, proposed an adaptation of semiseparable matrices for a non-causal (non-triangular) working mode.

As always, I encourage you to check out our Practical Generative AI course, where the discussions often serve as a source of knowledge and inspiration for posts like this one.

author
Stanislav Fedotov
AI evangelist at Nebius, AI program lead at AI DT School