Lately I've been wondering... is this a problem, or a strength?
It might be a fallacy to compare how LLMs "think" with how humans think. But humor me for a second. When you are speaking, each time you emit a word, you are not attending to every previous word in your sentence (like transformers); rather, you have a state in your mind that represents the grammar and concepts, and that state is continuously updated as you speak (more similar to SSMs).
Similarly, when you read a book, every time you read a word, you are not attending to every previous word in the book. Your model of "the book" is rather a fuzzy/approximate state that is updated with new information every time a new word appears. Right? (I'm sorry, I know this is very handwavy and pseudoscientific, but bear with me.)
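To make the contrast concrete, here's a toy numpy sketch of the two reading styles. Everything here is made up for illustration (the shapes, the decay constant, the lack of learned weights); it's just "re-read everything on every step" vs "fold each word into one fixed-size state":

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4                              # toy embedding size
tokens = rng.normal(size=(6, d))   # 6 "words" already emitted

# Transformer-style: to produce the next word, look back at *every*
# previous token (work grows with sequence length).
def attend(query, history):
    scores = history @ query                  # similarity to each past token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                  # softmax over the whole history
    return weights @ history                  # weighted sum of all past tokens

# SSM/RNN-style: keep one fixed-size state and fold each word in as it
# arrives; the history itself is never revisited.
def update_state(state, token, decay=0.9):
    return decay * state + (1 - decay) * token

state = np.zeros(d)
for t in tokens:
    state = update_state(state, t)

ctx = attend(rng.normal(size=d), tokens)
# Both end up with a d-dim summary, but attention re-read all 6 tokens
# for this one query, while the state was built incrementally and is
# the same size no matter how long the book gets.
```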
Ok, so if (big if) you feel like the above is true, then to match human-type language modelling, SSMs seem more human-like than transformers.
BUT... then aren't transformers strictly better in terms of accuracy? A transformer never "forgets" information, as long as it is within the context window, because it revisits that information every time it emits a new token.
So let's say we can remove the "quadratic attention" problem of transformers with SSMs. That's a nice training/inference performance boost. But... look at where we got with "naive" attention: GPT-4, Claude 3. It's not like we're hitting a wall with quadratic attention. It's absurdly more expensive than SSMs, but GPUs certainly aren't getting slower. If all AI work stopped now and only hardware improved, it wouldn't be long until GPT-4 could run on local hardware, right, assuming Moore's law holds?
/end rant, not really sure what my point was. I'm not against SSMs (they're cool); rather, I'm wondering if the SOTA will ever be an SSM when attention is so damn good.
It probably depends. But an idea I've been playing with: because transformers have such a strong ability for recall during inference, they might be introducing a strong inductive bias toward memorization as opposed to generalization. Why bother to build a complete world model when you can just attend to the answer? The global minimum in loss (at least on the training dataset) would favor memorizing and interpolating circuits over those that generalize well. This seems consistent with LLMs as they exist today: superhuman at recall, very mediocre at reasoning. Though, for what it's worth, existing SSMs haven't yet shown they can outperform (or even match) transformers when it comes to reasoning.
If this hypothesis were true, you might expect to see grokking in state space models more quickly than in transformer models.
(Even if it's hard to train transformers to generalize, superhuman recall is still incredibly valuable, and likely a hybrid system would offer the best of both worlds.)
The innovation is not the speed, but the lack of recursion or iteration. Humans, even accomplished ones, have to reread sections and really 'internalize' ideas before being able to summarize, and very few humans can, in a single attempt, generate perfect speech. Most of us speak and unknowingly revise our own speech as we go along. Unlike transformers, which speak confidently, we start a sentence and then decide halfway through that it's not going where we like. Then we start it over again, and by the powers of human attention, no one seems to really notice.
Transformers are just insanely complicated and expensive to train.
I don't know how the reasoning part comes to us, but if we could implant that capability into a transformer model, then it would end up pretty good.
I'll just add the observation that when we do this, it's largely based on feedback received from the recipient (well, so long as you're talking-with as opposed to talking-at): we're paying attention to how the audience is paying attention or not, any small facial tics that might betray skepticism or agreement, and so on. I'm looking forward to interacting with an LLM that pairs an emotion vector with each token it has previously produced.
hume.ai goes a long way in analyzing audio; it's just a matter of time before they're also ingesting realtime facial cues to incorporate their audience's reaction into the choice of what to say next.
But that's the efficiency-effectiveness tradeoff that we have to make: given that compute is limited, would we prefer attention over shorter sequences or SSMs over longer sequences? The answer is probably "well, it depends on your use case" - I can definitely see reasons for both!
A fairly compelling thought for me is hybrid architectures (Jamba is a recent one). Here you can imagine having perfect recall over recent tokens and lossy recall over distant tokens. E.g. if the AI is generating a feature-length film, you "could imagine having Attention look at the most recent frames for short-term fluidity and an SSM for long-term narrative consistency" (quote from the OP)
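A toy numpy sketch of that "perfect recall nearby, lossy recall far away" idea. To be clear, this is purely illustrative and not Jamba's actual layer layout; the window size, decay constant, and absence of learned projections are all made up:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
frames = rng.normal(size=(10, d))   # e.g. embeddings of 10 generated frames

# Hypothetical hybrid read-out: exact softmax attention over the most
# recent `window` items, plus one exponentially-decayed state that
# lossily summarizes everything older.
def hybrid_context(query, seq, window=3, decay=0.95):
    recent, old = seq[-window:], seq[:-window]

    # lossy long-range memory: a single d-dim state for all old items
    state = np.zeros(seq.shape[1])
    for x in old:
        state = decay * state + (1 - decay) * x

    # perfect short-range recall: attention over the recent window only
    scores = recent @ query
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ recent + state

ctx = hybrid_context(rng.normal(size=d), frames)
# cost per step: O(window) attention + O(1) state update, instead of
# O(sequence length) attention over everything.
```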
> Lately I've been wondering... is this a problem, or a strength?
Exactly. There are a lot of use cases where perfect recall is important. And earlier data may be more or less incompressible, such as when an LLM is working on a large table of data.
Maybe we'll end up with different architectures being used for different applications. E.g. simple chat may be OK with an RNN type architecture.
I've also seen people combine Mamba and Transformer layers. Maybe that's a good tradeoff for some other applications.
It's a strength; fundamentally it's impossible to achieve the same degree of accuracy with a sub-quadratic attention mechanism: https://arxiv.org/abs/2209.04881 (unless the Strong Exponential Time Hypothesis is false, which is very unlikely, like P=NP).
I was wondering the same thing. I understand why the initial developers of this method declared it a strength. Still, I think it's a problem too:
If the Transformer reads this sentence:
A equals B
it understands that B comes after A, and therefore A equals B. But how does it learn that after A comes B, and therefore B equals A?
I am referring to the logical problems that most (all?) modern language models suffer from.
Current models are horrendously inefficient though, so with architectural improvements we'll have something of that capability far sooner on weaker hardware.
We are not hitting a wall, but a slope. Hardware improvements will not make up for it indefinitely. Software will have to make up for it, but the problem is that it costs millions of dollars to hit compile.
I was doing exactly this until late in my youth, until I learnt that people do it sequentially. But it is doable to create connections and pick the sensible case. Not the most relaxing thing.
For the uninitiated (like me): apparently it stands for State Space Models.
It shows, more likely than not, that we are also parrots.
For example: "all previous tokens can be passed to the current token." That seems like a poorly constructed sentence. A token is not a function, and it's not an algorithm either... how can you pass tokens to a token? This type of ambiguous language in academic papers makes them hard to read. Maybe the phrase 'every token has an association with every previously encountered token' would be better? Or 'every token is used to compute the token vector for each token'... I don't know; all I can do is guess at the meaning of the word 'passed'. They want us to infer and fill in the gaps with our own assumptions. It assumes that we are primed to think in a certain highly constrained way...
For some reason, a lot of academia around AI is littered with such imprecise language. They choose to use niche concepts and repurposed wording that their own small community invented, rather than using words and ideas that are more widely understood but would convey the same information.
Rational people who aren't directly involved in those fields, and who generally resist jumping to conclusions, will struggle to understand what is meant, because a lot of those words and ideas have different interpretations in their own fields.
I studied machine learning at university, wrote ANNs from scratch, and trained them, and even I find the language and concepts around LLMs too ambiguous. I'd rather just ask ChatGPT.
One thing that bothers me is that the community has moved away from relating concepts to neurons, interconnections, input layers, hidden layers, and output layers. Instead, they jump straight into vectors and matrices, pretending as though there is only one way to map those calculations onto neurons and weights. But in fact, this abstraction has many possible interpretations. You could have fully connected layers or partially connected layers... maybe you need a transformer only in front of the input layer, or between every layer... so many possibilities.
The entire article means little if considered in isolation outside of the context of current configurations of various popular frameworks and tools.
That statement is meaningfully different from "all previous tokens can be passed to the current token", and both really make sense if you understand attention mechanisms.
Do you pass information from other tokens to a token, in the sense that each token processes information from other tokens? A token isn't a processing unit, AFAIK; it's just a word part. The processing is not the responsibility of the token itself. My understanding is that tokens may be associated with each other via an external structure, but not passed to each other. Or maybe they meant a token vector? And the token vector contains information from related tokens? It's unclear.
To me, 'passed' means data passed to a function or algorithm for processing. It's confusing unless a token is a function or algorithm.
My point is that this language only makes sense if you are already up to date in that field.
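FWIW, here's a toy numpy version of what I think "passed" is getting at (no learned projections, no multiple heads; shapes made up for illustration): the vector computed *at* position i is literally a weighted sum over the vectors at positions <= i, so information from earlier tokens flows into the current one:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(5, 4))   # one embedding vector per token

# Toy single-head causal self-attention, stripped of learned weights.
def causal_self_attention(x):
    n = x.shape[0]
    scores = x @ x.T                              # pairwise similarities
    mask = np.tril(np.ones((n, n), dtype=bool))   # no peeking at later tokens
    scores = np.where(mask, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)            # softmax per position
    return w @ x    # row i is a mix of the vectors for tokens 0..i only

out = causal_self_attention(x)
```

So "passing all previous tokens to the current token" isn't the token doing anything itself; it's the network computing the current position's vector from all earlier positions' vectors.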
I see what you did there
https://news.ycombinator.com/item?id=39501982 https://www.kolaayonrinde.com/blog/2024/02/11/mamba.html
Adding: this resurgence of interest in Mamba is also due to some actual SOTA progress with SSMs, like the new model AI21 Labs released this week [1], and we're likely to see others merging different architecture layers (it's a 52B MoE with 12B params active during inference, blending both Mamba and transformer layers).
>As the first production-grade model based on Mamba architecture, Jamba achieves an unprecedented 3X throughput and fits 140K context on a single GPU.