How transformers, RNNs and SSMs are more alike than you think
Uncovering surprising links between seemingly unrelated LLM architectures may pave the way for an effective exchange of ideas and for boosting efficiency.
August 21, 2024
10 mins to read
Even in the age of Mamba and other exciting linear RNNs and state space models (SSMs), transformer architectures still reign supreme as the backbone of LLMs. However, this might soon change: hybrid architectures such as Jamba, Samba and Griffin are quite promising. They are significantly more time- and memory-efficient than transformers, and their capabilities are not greatly diminished compared to attention-based LLMs.
At the same time, recent research has exposed deep connections between different architectural options: transformers, RNNs, SSMs and matrix mixers. This is very exciting because it allows for the transfer of ideas from one architecture to another. In this long read, we’ll mainly follow Transformers are RNNs and Mamba 2, getting elbow-deep in algebra to understand how:
Transformers may sometimes be RNNs (section 2),
State space models may hide inside the mask in the self-attention mechanism (section 4),
Mamba may sometimes be rewritten as masked self-attention (section 5).
I find these connections quite fascinating, and I hope you, too, will appreciate them.
If you’re curious to learn more about the architectural development of transformer alternatives, please feel free to check my separate post about them.
A classic LLM self-attention layer looks as follows:
Or, in more detail:
Let’s briefly recall what happens here:
We multiply the matrix $Q$ of queries by the matrix $K^T$ of keys. The product is an $L \times L$ matrix of scalar products of queries and keys.
We normalize the resulting matrix by $\frac{1}{\sqrt{d}}$.
Then, we multiply it elementwise by the $L \times L$ attention mask. The picture shows the default causal mask: the 0-1 matrix on the left. It zeroes out the products (earlier query, later key), preventing the attention from “looking into the future”.
Then, we take softmax.
Finally, we multiply the attention weights (let’s denote this matrix by $A$) by the values $V$; the $t$-th row of the output is:
$$\text{output}_t = \sum_i a_{ti} v_i.$$
Thus, the $i$-th value is taken with the “attention weight of the $t$-th query for the $i$-th key”.
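To make these steps concrete, here is a minimal NumPy sketch of a single attention head. The shapes, the function name and the -inf masking trick (the standard way to get exactly zero attention weight on masked positions) are illustrative choices of mine, not taken from any particular implementation:

```python
import numpy as np

def causal_softmax_attention(Q, K, V):
    """Single-head masked self-attention; Q, K, V have shape (L, d)."""
    L, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                 # (L, L) scaled dot products
    mask = np.tril(np.ones((L, L), dtype=bool))   # 0-1 causal mask: True on and below the diagonal
    scores = np.where(mask, scores, -np.inf)      # masked positions get zero weight after softmax
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)            # softmax over keys: each row sums to 1
    return A @ V                                  # output_t = sum_i a_ti * v_i

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 6, 4))              # L = 6, d = 4
print(causal_softmax_attention(Q, K, V).shape)    # (6, 4)
```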
Many design choices in this architecture may be altered. Let’s explore some possible changes.
Softmax in the attention formula ensures that the values are mixed with positive coefficients that sum to 1. This is good because it keeps some of the statistics intact, but it is very restricting: no term can escape the softmax’s brackets, even if we would love to play with associativity, as in $(QK^T)V = Q(K^TV)$. Why do we need associativity, you ask? The thing is that changing the order of multiplication can drastically influence the complexity:
In the left formula, we need to calculate an $L \times L$ matrix, which gives $O(L^2 d)$ time cost and $O(L^2)$ memory cost if we allow this matrix to fully materialize in memory.
In the right formula, we need to calculate a $d \times d$ matrix, which gives $O(Ld^2)$ time cost and $O(d^2)$ memory cost if we allow this matrix to fully materialize in memory.
Thus, as context length grows, the first formula becomes more costly than the second one and ultimately becomes prohibitively expensive.
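A quick sanity check of the associativity argument, with no mask or softmax yet and arbitrary illustrative sizes:

```python
import numpy as np

L, d = 1000, 64
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, L, d))

left = (Q @ K.T) @ V    # materializes an L x L matrix: O(L^2 d) time, O(L^2) memory
right = Q @ (K.T @ V)   # materializes a  d x d matrix: O(L d^2) time, O(d^2) memory

print(np.allclose(left, right))  # True: same result, very different cost profiles
```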
So, you know what? Let’s crack this softmax open! Let’s write down the formula with the softmax in a bit more detail:
$$y_t = \sum_{s=1}^{L} \frac{\exp\!\left(m_{ts} \cdot q_t k_s^T / \sqrt{d}\right)}{\sum_{j=1}^{L} \exp\!\left(m_{tj} \cdot q_t k_j^T / \sqrt{d}\right)}\, v_s,$$
where the fraction is the softmax.
The most inconvenient thing here is the exponent; we just can’t take anything out of it. Let’s just erase it.
$$y_t = \sum_{s=1}^{L} \frac{m_{ts} \cdot q_t k_s^T}{\sum_{j=1}^{L} m_{tj} \cdot q_t k_j^T}\, v_s.$$
Note that the normalizer $\frac{1}{\sqrt{d}}$ vanished as well.
The problem with this formula is that $q_t k_s^T$ is not guaranteed to be positive, so we may end up mixing the values with coefficients of different signs (which is counterintuitive); and even worse, the denominator may vanish, breaking everything. We may apply a “good” elementwise function $\phi$ (a kernel) to the queries and keys to mitigate this:
$$y_t = \sum_{s=1}^{L} \frac{m_{ts} \cdot \phi(q_t) \phi(k_s)^T}{\sum_{j=1}^{L} m_{tj} \cdot \phi(q_t) \phi(k_j)^T}\, v_s.$$
A word of caution. The relations between $M$, $K^T$ and $V$ in the brackets are now quite complex and are not expressed through just ordinary and elementwise matrix products. We’ll discuss this computational block in more detail in the next section.
Now, if M is the causal mask with 1 on the diagonal and below, and 0 above the diagonal
then everything becomes even simpler:
$$y_t = \frac{\phi(q_t) \sum_{s=1}^{t} \phi(k_s)^T v_s}{\phi(q_t) \sum_{j=1}^{t} \phi(k_j)^T}.$$
And this can be calculated in a very simple, recurrent way:
$$z_t = z_{t-1} + \phi(k_t)^T, \qquad h_t = h_{t-1} + \phi(k_t)^T v_t, \qquad y_t = \frac{\phi(q_t)\, h_t}{\phi(q_t)\, z_t}.$$
I think you won’t be surprised to learn that the ICML 2020 paper where linearized attention was first suggested is called Transformers are RNNs. Note that we have two hidden states here: a vector $z_t$ and a matrix $h_t$ ($\phi(k_t)^T v_t$ is a column times a row, so it’s a $d \times d$ matrix). In recent papers, linearized attention is often presented in an even simpler, distilled way, with no $\phi$ and no denominator:
$$y_t = q_t \sum_{s=1}^{t} m_{ts}\, k_s^T v_s.$$
Linearized attention has two important benefits:
As a recurrent mechanism, it has linear complexity at inference time with respect to the sequence length L,
As a transformer model, it can be efficiently trained in a parallel way.
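Here is a minimal sketch of both views. It uses the 0-1 causal mask and picks $\phi(x) = \mathrm{elu}(x) + 1$ as the kernel, which is one common choice from the linear-attention literature, not something required by the derivation above:

```python
import numpy as np

def phi(x):
    # elu(x) + 1: a positive elementwise kernel (one common choice)
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_recurrent(Q, K, V):
    """RNN view: O(L d^2) time, with a matrix state h and a vector state z."""
    L, d = Q.shape
    h = np.zeros((d, d))            # accumulates phi(k_t)^T v_t
    z = np.zeros(d)                 # accumulates phi(k_t)
    out = np.zeros_like(V)
    for t in range(L):
        q, k, v = phi(Q[t]), phi(K[t]), V[t]
        h += np.outer(k, v)         # phi(k_t)^T v_t, a d x d update
        z += k
        out[t] = (q @ h) / (q @ z)  # y_t = phi(q_t) h_t / (phi(q_t) z_t)
    return out

def linear_attention_parallel(Q, K, V):
    """Transformer view: mask applied to phi(Q) phi(K)^T, quadratic in L."""
    L, _ = Q.shape
    M = np.tril(np.ones((L, L)))            # 0-1 causal mask
    S = M * (phi(Q) @ phi(K).T)             # masked similarity scores
    return (S / S.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 8, 4))
print(np.allclose(linear_attention_recurrent(Q, K, V),
                  linear_attention_parallel(Q, K, V)))   # True
```

The loop is the RNN view with the two hidden states $h_t$ and $z_t$; the matrix version is the transformer view that can be computed in parallel during training.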
At this point, you probably have a pressing question: if linearized attention is so cool, why don’t we see it in every LLM, and why do we still complain about attention’s quadratic complexity? Alas, linearized-attention LLMs are less stable during training and reach somewhat lower capability than standard self-attention models. It looks as if less information can be squeezed through the fixed $d \times d$ bottleneck than through the adaptable $L \times L$ one. You can check experimental results in this article by Manifest AI, for example.
Further reading. The idea that RNNs and linearized attention are related is rediscovered and revisited in many recent works. Moreover, it’s becoming common in such papers to have a matrix hidden state that is updated as follows:
$$h_t = h_{t-1} + k_t^T v_t,$$
where $k_t$ and $v_t$ are a “key” and a “value” of sorts, with the output of an RNN layer being something like
$$q_t h_t = q_t (h_{t-1} + k_t^T v_t) = \dots = q_t (k_1^T v_1 + \dots + k_t^T v_t) = (q_t k_1^T) v_1 + \dots + (q_t k_t^T) v_t,$$
which boils down to linear attention.
As examples, I would suggest checking these two papers:
xLSTM, a May 2024 paper suggesting a development of the renowned LSTM recurrent architecture. Its mLSTM block features a matrix hidden state which is updated in the already-familiar way:
$$c_t = \text{forget\_gate}_t \cdot c_{t-1} + \text{input\_gate}_t \cdot v_t k_t^T$$
and is further multiplied by a “query” to obtain the output. (The paper’s linear algebra setup is transposed relative to ours: queries, keys and values are columns, not rows; hence the odd order in $v_t k_t^T$.)
Learning to (Learn at Test Time), which appeared in July 2024 and is another RNN architecture with a matrix hidden state and a very peculiar semantics: the hidden state is the parameter $W$ of a function which is optimized by gradient descent as we iterate over $t$, roughly
$$W_t = W_{t-1} - \eta\, \nabla \ell(W_{t-1}; x_t).$$
Here, everything is also transposed, so the order is odd again. The math is more complicated than $W_t = W_{t-1} + v_t k_t^T$, but may be simplified to it.
Now that we have simplified the masked attention mechanism, we can start developing it again. And there is obvious room for experiments: choosing a different lower triangular (“never look into the future”) mask $M$ instead of the ascetic 0-1 causal mask. But first, we need to deal with the emerging inefficiency.
In the previous section, we were actually very lucky to work with a simple 0-1 causal mask $M$. In the general case, the recurrent trick doesn’t work anymore:
$$y_t = \frac{\phi(q_t) \sum_{s=1}^{t} m_{ts} \cdot \phi(k_s)^T v_s}{\phi(q_t) \sum_{j=1}^{t} m_{tj} \cdot \phi(k_j)^T}.$$
The coefficients $m_{ts}$ are different now, and there will be no recurrent formula relating $y_3$ to $y_2$. Thus, for each $t$ we’ll have to compute the sum from scratch, making the asymptotic complexity quadratic in $L$ again, instead of linear.
The answer is that we won’t use just any masks $M$, only special, “good” ones, namely the ones that can be quickly multiplied (not elementwise!) by other matrices. To see how we can benefit from this property, let’s understand how to calculate the masked-attention output
$$y_{sj} = \sum_i q_{si} \sum_t m_{st} k_{ti} v_{tj},$$
that is, $Y = (M \odot QK^T)V$, where $\odot$ denotes the elementwise product.
For the further discussion, it will be useful for us to color indices instead of blocks:
Now, we are ready to formulate the four-step algorithm:
Step 1. Take $K$ and $V$ and create a three-dimensional tensor $Z$ with $z_{tij} = k_{ti} v_{tj}$:
(Each axis is marked with its length.) This step takes $O(Ld^2)$ time and memory.
Note that if we sum this tensor over the magenta axis $t$, we’ll get the product $K^TV$:
$$\sum_t k_{ti} v_{tj} = \sum_t (K^T)_{it} v_{tj} = (K^TV)_{ij}.$$
Step 2. Now, we multiply $M$ by this tensor (not elementwise!) in such a way that $M$ is multiplied by each “column” of $Z$ along the magenta axis $t$.
This gives us exactly
$$\sum_t m_{st} k_{ti} v_{tj} = \sum_t m_{st} (K^T)_{it} v_{tj}.$$
Let’s denote the result by $H$. Now, it only remains to multiply everything by $Q$, which will be done in two more steps.
Step 3a. Take $Q$ and multiply it elementwise with each $j = \text{const}$ layer of $H$:
This results in:
$$q_{si} \sum_t m_{st} k_{ti} v_{tj}.$$
This step requires $O(Ld^2)$ time and memory.
Step 3b. Sum the resulting tensor along the $i$ axis:
This step also requires $O(Ld^2)$ time and memory. And we obtain exactly the thing we need:
$$y_{sj} = \sum_i q_{si} \sum_t m_{st} k_{ti} v_{tj}.$$
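To double-check that the four steps really compute $(M \odot QK^T)V$, here is a small einsum-based sketch. It materializes all intermediate tensors explicitly for clarity; a real kernel would avoid that and exploit the structure of $M$ in step 2:

```python
import numpy as np

L, d = 8, 4
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, L, d))
M = np.tril(rng.normal(size=(L, L)))   # any lower triangular mask

# Step 1: Z[t, i, j] = K[t, i] * V[t, j]
Z = np.einsum("ti,tj->tij", K, V)

# Step 2: multiply M along the t axis: H[s, i, j] = sum_t M[s, t] * Z[t, i, j]
H = np.einsum("st,tij->sij", M, Z)     # the only step that touches the mask

# Steps 3a + 3b: multiply by Q elementwise over i, then sum over i
Y = np.einsum("si,sij->sj", Q, H)

# Reference: masked attention written as a single formula
Y_ref = (M * (Q @ K.T)) @ V
print(np.allclose(Y, Y_ref))           # True
```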
The most crucial step is the second one, and we deliberately omitted its complexity analysis. A naïve estimate is:
$O(L^2)$ for each multiplication of $M$ by a vector,
Repeating it $d^2$ times,
So it’s a whopping $O(L^2 d^2)$. But we won’t take just any $M$; we will take $M$ such that the complexity of multiplying $M$ by a vector is $O(RL)$ for some constant (and not very large) $R$.
For example, if $M$ is the 0-1 causal matrix, multiplication by it boils down to computing cumulative sums, which is done in $O(L)$ time. But there are many more options for structured matrices with fast vector multiplication.
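Here is the 0-1 causal case as a tiny sketch: multiplying by the mask is literally a cumulative sum.

```python
import numpy as np

L = 6
rng = np.random.default_rng(0)
x = rng.normal(size=L)

M = np.tril(np.ones((L, L)))               # the 0-1 causal mask
print(np.allclose(M @ x, np.cumsum(x)))    # True: O(L) instead of O(L^2)
```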
In the next section, we’ll discuss one such matrix type, semiseparable matrices, which will turn out to be tightly connected with state space models.
Let’s recall that (discretized) state space models (SSMs) are a class of sequential models connecting a 1-dimensional input $x_t$, an $r$-dimensional hidden state $h_t$, and a 1-dimensional output $y_t$ in the following way:
$$h_t = A_t h_{t-1} + B_t x_t, \qquad y_t = C_t h_t + D_t x_t.$$
So, in the discretized form, it’s just a fancy linear RNN with a skip connection. Moreover, for the rest of the story, we’ll even forget about skip connections by setting $D_t = 0$.
Now, let’s write our SSM as a single matrix multiplication $y = Mx$,
where
$$M_{ts} = \begin{cases} C_t A_t A_{t-1} \cdots A_{s+1} B_s, & t \ge s, \\ 0, & t < s, \end{cases}$$
with the empty product for $t = s$ understood as the identity, so that $M_{tt} = C_t B_t$.
So, M is lower triangular, just like the attention masks we discussed earlier.
This type of matrix comes with a great boon:
An $L \times L$ lower triangular matrix whose elements may be represented in this way can be stored using $O(rL)$ memory and has a matrix-vector multiplication complexity of $O(rL)$ instead of the default $O(L^2)$ (see this paper for details).
This means that every state space model gives rise to a structured attention mask M which can be used in an efficient transformer model with linearized attention.
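Here is a hedged sketch of the complexity claim for diagonal $A_t$, a common SSM parameterization chosen purely for illustration: the recurrence multiplies by the semiseparable $M$ in $O(rL)$ time and $O(r)$ extra memory, without ever building the $L \times L$ matrix.

```python
import numpy as np

L, r = 6, 3
rng = np.random.default_rng(0)
a = rng.uniform(0.5, 0.99, size=(L, r))   # diagonal transitions A_t = diag(a_t)
B = rng.normal(size=(L, r))               # input maps B_t
C = rng.normal(size=(L, r))               # output maps C_t
x = rng.normal(size=L)

# Recurrent route: O(rL) time, O(r) state, M is never materialized
h = np.zeros(r)
y_rec = np.zeros(L)
for t in range(L):
    h = a[t] * h + B[t] * x[t]            # h_t = A_t h_{t-1} + B_t x_t
    y_rec[t] = C[t] @ h                   # y_t = C_t h_t   (D_t = 0)

# Dense route: build the L x L semiseparable matrix explicitly
M = np.zeros((L, L))
for t in range(L):
    for s in range(t + 1):
        decay = np.prod(a[s + 1:t + 1], axis=0)   # A_t A_{t-1} ... A_{s+1}, diagonal case
        M[t, s] = C[t] @ (decay * B[s])

print(np.allclose(y_rec, M @ x))          # True
```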
The curious thing, though, is that a semiseparable matrix M is already quite complex and expressive, even without all the query-key-value machinery around it. Moreover, it may itself be a masked attention mechanism. We’ll see this in the next section.
Here, we finally arrive at one of the central results of the Mamba 2 paper.
Let’s again consider $y = Mu$, where $u = u(x)$ is a function of the inputs and $M$ is a semiseparable matrix. Moreover, let’s consider a very special case where each $A_t$ is a scalar matrix: $A_t = a_t I$. Then the formulas become especially simple:
$$M_{ts} = C_t A_t A_{t-1} \cdots A_{s+1} B_s = a_t a_{t-1} \cdots a_{s+1} \cdot C_t B_s,$$
and $C_t B_s$
is just a number. Moreover, we can stack the $C_t$ (as rows) and the $B_s$ (as columns) into matrices $C$ and $B$ such that
$$(CB)_{ts} = C_t B_s.$$
Now, we’ll also need the matrix $G$ defined by
$$G_{ts} = \begin{cases} a_t a_{t-1} \cdots a_{s+1}, & t \ge s, \\ 0, & t < s, \end{cases}$$
with $G_{tt} = 1$.
Then it’s easy to check that
$$Mu = (G \odot CB)\, u.$$
Have we seen this somewhere? Indeed! It’s masked attention (a small numerical check follows the list below) with:
G for the mask,
C for Q, the query matrix,
B for KT, the transposed key matrix,
u for V, the value matrix.
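A small numerical check of this duality under the same assumptions (scalar $A_t = a_t I$, scalar inputs, $D_t = 0$). Note that the sketch stacks both $B_t$ and $C_t$ as rows, so the score matrix appears as $CB^T$; that is the $CB$ above with $B$ playing the role of $K^T$:

```python
import numpy as np

L, r = 6, 3
rng = np.random.default_rng(0)
a = rng.uniform(0.5, 0.99, size=L)        # scalar transitions A_t = a_t * I
B = rng.normal(size=(L, r))               # B_t, stacked as rows; plays the role of keys
C = rng.normal(size=(L, r))               # C_t, stacked as rows; plays the role of queries
u = rng.normal(size=L)                    # scalar inputs, playing the role of values

# SSM route: run the recurrence
h = np.zeros(r)
y_ssm = np.zeros(L)
for t in range(L):
    h = a[t] * h + B[t] * u[t]
    y_ssm[t] = C[t] @ h

# Attention route: mask G of cumulative products, "scores" C B^T, "values" u
G = np.zeros((L, L))
for t in range(L):
    for s in range(t + 1):
        G[t, s] = np.prod(a[s + 1:t + 1])  # a_t * ... * a_{s+1}; equals 1 when s == t
y_attn = (G * (C @ B.T)) @ u

print(np.allclose(y_ssm, y_attn))          # True
```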
In classic SSMs, B and C are constants, but Mamba made them data-dependent, reinforcing the alignment. This correspondence between specific state space models and masked attention was introduced in the Mamba 2 paper as state space duality.
Further reading. The idea of using matrix mixers instead of more sophisticated architectures is not new, with one of the notable early examples being MLP-Mixer which employed MLPs instead of convolution or attention for spatial mixing in CV tasks.
Although current research is mostly focused on LLMs, there are also several papers proposing non-transformer, matrix mixture architectures for encoder models. Examples include:
FNet from Google Research, with a matrix mixer $M$ based on the Fourier transform,
Hydra, which, among other ideas, proposed an adaptation of semiseparable matrices for non-causal (non-triangular) working mode.
As always, I encourage you to check out our Practical Generative AI course, where the discussions often serve as a source of knowledge and inspiration for posts like this one.