Transformers are hard to describe with analogies, and to be fair there is no good explanation of why they work, so it may be better to just present the mechanism, "leaving the interpretation to the viewer". Also, it's simpler to describe dot products as vectors projecting onto one another.
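A minimal sketch of the "dot product as projection" reading, assuming NumPy; the vectors and the helper name `scalar_projection` are made up for illustration:

```python
import numpy as np

def scalar_projection(a, b):
    """Length of the shadow that `a` casts along the direction of `b`."""
    return np.dot(a, b) / np.linalg.norm(b)

a = np.array([3.0, 4.0])
b = np.array([2.0, 0.0])

dot = np.dot(a, b)              # 3*2 + 4*0
proj = scalar_projection(a, b)  # dot divided by |b|

# The dot product is the projection of a onto b, scaled by the length of b.
print(dot, proj, proj * np.linalg.norm(b))
```

Read this way, a query/key dot product in attention measures how far one vector extends along the other, scaled by the second vector's length.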
Why does 'mat' follow from 'the cat sat on the ...'? Because 'mat' is the most frequent word in the dataset, and the NN is a model of those frequencies.
Why is 'London in UK' "known" but 'London in France' isn't? Just because 'UK' occurs much more frequently in the dataset.
The algorithm isn't doing anything other than aligning computation to hardware; the computation isn't doing anything interesting. The value comes from the conditional probability structure in the data, and that comes from people arranging words usefully, because they're communicating information with one another.
P(next_word|previous_words) is ridiculously hard to estimate in a way that is actually useful. Remember how bad text generation used to be before GPT? There is innovation in discovering an architecture that makes it possible to learn P(next_word|previous_words), in addition to the computing techniques and hardware improvements required to make it work.
> NNs are a stat fitting alg learning a conditional probability distribution, P(next_word|previous_words).
They are trained to maximise this, yes.
> Their weights are a model of this distribution.
That doesn't really follow, but let's leave that.
> Why does 'mat' follow from 'the cat sat on the ...'? Because 'mat' is the most frequent word in the dataset, and the NN is a model of those frequencies.
Here's the rub. If how you describe them is all they're doing then a sequence of never-before-seen words would have no valid response. All words would be equally likely. It would mean that a single brand new word would result in absolute gibberish following it as there's nothing to go on.
Let's try:
Input: I have one kjsdhlisrnj and I add another kjsdhlisrnj, tell me how many kjsdhlisrnj I now have.
Result: You now have two kjsdhlisrnj.
I would wager a solid amount that kjsdhlisrnj never appears in the input data. If it does, pick another one; it doesn't matter.
So we are learning something more general than the frequencies of sequences of tokens.
I always end up pointing to this but OthelloGPT is very interesting https://thegradient.pub/othello/
While it's trained on sequences of moves, what it does is more than just "sequence a,b,c is followed by d most often"
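To make the unseen-word argument concrete, here is a small sketch of what a model that only stored frequencies would do. The toy corpus and the function names are invented for illustration; this is a plain bigram count table, not how any real LLM is built:

```python
from collections import Counter, defaultdict

def train_bigram_counts(tokens):
    """Count how often each token follows each previous token."""
    counts = defaultdict(Counter)
    for prev, nxt in zip(tokens, tokens[1:]):
        counts[prev][nxt] += 1
    return counts

def next_word_distribution(counts, prev):
    """Relative frequencies of the words seen after `prev` (empty if unseen)."""
    following = counts.get(prev)
    if not following:
        return {}
    total = sum(following.values())
    return {word: c / total for word, c in following.items()}

corpus = ("the cat sat on the mat . the cat sat on the mat . "
          "the dog sat on the rug .").split()
counts = train_bigram_counts(corpus)

seen = next_word_distribution(counts, "the")
unseen = next_word_distribution(counts, "kjsdhlisrnj")

print(seen)    # a proper distribution over words that followed "the"
print(unseen)  # nothing to go on for a never-seen token
```

A pure lookup table has no answer at all for the made-up word, whereas the model in the example above answers sensibly, which is the point being argued.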
Or if you just want to say that NNs are used as a statistical model here: Well, yeah, but that doesn't really tell you anything. Everything can be a statistical model.
E.g., you could also say "this is exactly the way the human brain works", but it doesn't really tell you anything about how it really works.
I can explain that car engines 'just' convert gasoline into forward motion. But if the person hearing the explanation is hoping to learn what a cam belt or a gearbox is, or why cars are more reliable now than they were in the 1970s, or what premium gas is for, or whether helicopter engines work on the same principle, they're going to need a more detailed explanation.
This recent $10,000 challenge is super super interesting imho. https://twitter.com/VictorTaelin/status/1778100581837480178
State of the art models are doing more than “just” predicting the probability of the next symbol.
So "mat" follows "the cat sat on the" because the model captures the entire worldview of the dataset used for training: not just the next-word probability based on one or more previous words, but probabilities over all the preceding meaning, and over the meanings behind that meaning, and so on.
You're confidently incorrect by oversimplifying all LLMs to a base model performing a completion from a trivial context of 5 words.
This is tantamount to a straw man. Not only do few people use untuned base models, it completely ignores in-context learning that allows the model to build complex semantic structures from the relationships learnt from its training data.
Unlike base models, instruct and chat fine-tuning teaches models to 'reason' (or rather, perform semantic calculations in abstract latent spaces) with their "conditional probability structure", as you call it, to varying extents. The model must learn to use its 'facts', understand semantics, and perform abstractions in order to follow arbitrary instructions.
You're also conflating the training metric of "predicting tokens" with the mechanisms required to satisfy this metric for complex instructions. It's like saying "animals are just performing survival of the fittest". While technically correct, complex behaviours evolve to satisfy this 'survival' metric.
You could argue they're "just stitching together phrases", but then you would be varying degrees of wrong:
For one, this assumes phrases are compressed into semantically addressable units, which is already a form of abstraction ripe for allowing reasoning beyond 'stochastic parroting'.
For two, it's well known that the first layers perform basic structural analysis such as grammar, and later layers perform increasing levels of abstract processing.
For three, it shows a lack of understanding in how transformers perform semantic computation in-context from the relationships learnt by the feed-forward layers. If you're genuinely interested in understanding the computation model of transformers and how attention can perform semantic computation, take a look here: https://srush.github.io/raspy/
For a practical example of 'understanding' (to use the term loosely), give an instruct/chat tuned model the text of an article and ask it something like "What questions should this article answer, but doesn't?" This requires not just extracting phrases from a source, but understanding the context of the article on several levels, then reasoning about what the context is not asserting. Even comparatively simple 4x7B MoE models are able to do this effectively.
What about cases that are not present in the dataset?
The model must be doing something besides storing raw probabilities to avoid overfitting and enable generalization (imagine that you could have a very performant model - when it works - but it sometimes would spew "Invalid input, this was not in the dataset so I don't have a conditional probability and I will bail out")
My mental justification for attention has always been that the output of the transformer is a sequence of new token vectors such that each individual output token vector incorporates contextual information from the surrounding input token vectors. I know it's incomplete, but it's better than nothing at all.
I thought the general consensus was: "transformers allow neural networks to have adaptive weights".
As opposed to the previous architectures, where every edge connecting two neurons always has the same weight.
EDIT: a good video, where it's actually explained better: https://youtu.be/OFS90-FX6pg?t=750&si=A_HrX1P3TEfFvLay
attention(Q, K, V) = softmax(Q K^T / √d_K) @ V
is just half a row; the multi-head, masking and positional stuff are just toppings
we have many basic algorithms in CS that are more involved, it's amazing we get language understanding from such simple math
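For what it's worth, the one-liner above really does fit in a few lines of NumPy. This is a single-head sketch with no masking, no multi-head split and no positional information; the shapes and random inputs are arbitrary assumptions:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_K)) @ V for one head; also returns the weights."""
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))
    return weights @ V, weights

rng = np.random.default_rng(0)
n_tokens, d_k, d_v = 4, 8, 8
Q = rng.normal(size=(n_tokens, d_k))
K = rng.normal(size=(n_tokens, d_k))
V = rng.normal(size=(n_tokens, d_v))

out, weights = attention(Q, K, V)
print(out.shape)             # one output vector per token
print(weights.sum(axis=-1))  # each row of weights is a probability distribution
```

Each output row is a weighted average of the rows of V, with the weights recomputed from the input every time.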
In my experience, with very few notable exceptions (e.g. Feynman), researchers are the worst when it comes to clearly explaining to others what they're doing.
I'm at the point where I'm starting to believe that pedagogy and research generally are mutually exclusive skills.
I'm curious what things other videos did worse compared to 3b1b?
I like how he was able to avoid going into the weeds and stay focused on leading you to understanding. I remember another video where I got really hung up on positional encoding and I felt like I couldn't continue until I understood that. Or other videos that overfocus on matrix operations or softmax, etc.
1. The standard terminology is "meh" at most. The word "attention" itself is just barely intuitive, "self-attention" is worse, and don't get me started about "key" and "value".
2. The key papers (Attention is All You Need, the BERT paper, etc.) are badly written. This is probably an unpopular opinion. But note that I'm not diminishing their merits. It's perfectly possible to write a hugely impactful, transformative paper describing an amazing breakthrough and yet not explain it very well. And that's exactly what happened, IMO.
3. The way in which these architectures were discovered was largely by throwing things at the wall and seeing what stuck. There was no reflection process that ended in a prediction that such an architecture would work well, which was then empirically verified. It's empirical all the way through. This means that we don't have a full understanding of why it works so well; all explanations are post hoc rationalizations (in fact, lately there is some work implying that other architectures may work equally well if tweaked enough). It's hard to explain something that you don't even fully understand.
Everyone who is trying to explain transformers has to overcome these three disadvantages... so most explanations are confusing.
I wouldn't say so. Historically it's quite common. Maxwell's EM papers used such convoluted notation that they are quite difficult to read. It wasn't until they were reformulated in vector calculus that they became infinitely more digestible.
I think though your third point is the most important; right now people are focused on results.
There's a reason The Illustrated Transformer[1] was/is so popular: it made the original paper much more digestible.
1. good communication requires an intelligence that most people sadly lack
2. because the type of people who are smart enough to invent transformers have zero incentive to make them easily understandable.
most documents are written by authors subconsciously desperate to mentally flex on their peers.
Often, the disseminating medium is one-sided, like a video or a blog post, which doesn't help, either. A conversational interaction would help the expert sense why someone outside the domain finds the subject confusing ("ah, I see what you mean"...), discuss common pitfalls ("you might think it's like this... but no, it's more like this...") etc.
In quantum mechanics, the state of your entire physical system is encoded as a very high dimensional normalized vector (i.e., a ray in a Hilbert space). The evolution of this vector through time is given by the time-translation operator for the system, which can loosely be thought of as a unitary matrix U (i.e., a probability preserving linear transformation) equal to exp(-iHt), where H is the Hamiltonian matrix of the system that captures its “energy dynamics”.
From the video, the author states that the prediction of the next token in the sequence is determined by computing the next context-aware embedding vector from the last context-aware embedding vector alone. Our prediction is therefore the result of a linear state function applied to a high dimensional vector. This seems a lot to me like we have produced a Hamiltonian of our overall system (generated offline via the training data), then we reparameterize our particular subsystem (the context window) to put it into an appropriate basis congruent with the Hamiltonian of the system, then we apply a one step time translation, and finally transform the resulting vector back into its original basis.
IDK, when your background involves research in a certain field, every problem looks like a nail for that particular hammer. Does anyone else see parallels here or is this a bit of a stretch?
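For readers without the physics background, here is a small numerical sketch of the quantum side of this analogy only: building U = exp(-iHt) from a Hermitian H and checking that it preserves the norm of a state. The random 4x4 Hamiltonian, the time step, and units with ħ = 1 are all assumptions for illustration:

```python
import numpy as np

def time_evolution_operator(H, t):
    """U = exp(-i H t) for Hermitian H, via its eigendecomposition (hbar = 1)."""
    energies, vecs = np.linalg.eigh(H)          # H = V diag(E) V^dagger
    phases = np.exp(-1j * energies * t)
    return vecs @ np.diag(phases) @ vecs.conj().T

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
H = (A + A.conj().T) / 2                        # make a Hermitian "Hamiltonian"

U = time_evolution_operator(H, t=0.3)

psi = rng.normal(size=4) + 1j * rng.normal(size=4)
psi /= np.linalg.norm(psi)                      # normalized state vector
psi_t = U @ psi                                 # one step of time translation

print(np.allclose(U.conj().T @ U, np.eye(4)))   # unitarity check
print(np.linalg.norm(psi_t))                    # norm after evolution
```

Whether a trained transformer's final prediction step behaves like such a norm-preserving linear map is exactly the open question the comment raises; the sketch does not settle it.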
One might model the human brain as a FSM as well, but I’m not sure I’d call the predictive ability of the brain an implementation detail.
In one part he talks about a thought experiment modeling the universe as a multidimensional cellular automaton, where fundamental particles are nothing more than the information they contain, and a collision between particles is a computation that tells that node and the adjacent nodes how to update their state.
Way out there, and I'm not saying there's any truth to it. But it was a really interesting and fun concept to chew on.
Either way, I find Planck time/energy to be a very spooky concept.
Didn’t make the “what was going on then in physics” connection until now.
Expanding on that, you could imagine that the intent of the sentence to complete (figuring out the murderer) would have to be captured in the first attention passes so that other layers would then be able to integrate more and more context in order to extract that information from the whole context. Also, it means that the forward passes for previous tokens need to have extracted enough salient high-level information already since you don't re-compute all attention passes for all tokens for each next token to predict.
You don't? I imagine the attention maps could be pretty different between n and n+1 tokens.
Edit: Or maybe you just meant you don't compute attention Σ(n) times for each new token?
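One way to read this exchange: with causal masking, the keys and values of earlier tokens do not change when a new token arrives, so they can be cached and only the new token's query has to be processed. The sketch below checks that on a single attention layer; the sizes, random weights and variable names are assumptions, and real implementations cache per layer and per head:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    """Full recomputation: token i may only attend to tokens 0..i."""
    n = Q.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
n, d = 5, 4
X = rng.normal(size=(n, d))                      # one embedding per token
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

full = causal_attention(X @ W_q, X @ W_k, X @ W_v)

# Incremental version: keep keys/values of past tokens, add one token at a time.
k_cache, v_cache, rows = [], [], []
for x in X:
    k_cache.append(x @ W_k)
    v_cache.append(x @ W_v)
    K, V = np.stack(k_cache), np.stack(v_cache)
    w = softmax((x @ W_q) @ K.T / np.sqrt(d))    # attention row for the new token
    rows.append(w @ V)
incremental = np.stack(rows)

print(np.allclose(full, incremental))
```

So the attention row for token n+1 is new, but the rows already computed for tokens 1..n stay the same, which is why they are not recomputed.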
Not with this GPT. The context size would not allow keeping attention to the total meaning of more than 2048 tokens (as reflected in the transformed embedding of that context's last token). For a substantial part of a novel, it would require a much larger context size, which then presumably would need a higher dimensional embedding/semantic space.
Which means the embedding model must do a lot of the lifting to be able to accurately represent meaning across long contexts so well. Now I want to know more about how those models are derived.
(is my understanding)
each prediction is based on a weighted context of all previous tokens, not just the immediately preceding one.
I suppose that when each element in the vector is 16 bits, the space is immense and capable of holding a novel in a single point.
This complements the detailed description provided by 3blue1brown
[1] https://medium.com/@daniellefranca96/gpt4-all-details-leaked...
Otherwise it feels like lots of machinery created out of nowhere. Lots of calculations and very little intuition.
Jeremy Howard made a comment on Twitter that he had seen various versions of this idea come up again and again - implying that this was a natural idea. I would love to see examples of where else this has come up so I can build an intuitive understanding.
1) The initial seq-2-seq approach was using LSTMs - one to encode the input sequence, and one to decode the output sequence. It's amazing that this worked at all - encode a variable length sentence into a fixed size vector, then decode it back into another sequence, usually of different length (e.g. translate from one language to another).
2) There are two weaknesses of this RNN/LSTM approach - the fixed size representation, and the corresponding lack of ability to determine which parts of the input sequence to use when generating specific parts of the output sequence. These deficiencies were addressed by Bahdanau et al in an architecture that combined encoder-decoder RNNs with an attention mechanism ("Bahdanau attention") that looked at each past state of the RNN, not just the final one.
3) RNNs are inefficient to train, so Jakob Uszkoreit was motivated to come up with an approach that better utilized available massively parallel hardware. He noted that language is as much hierarchical as sequential, suggesting a layered architecture where at each layer the tokens of the sub-sequence would be processed in parallel, while retaining a Bahdanau-type attention mechanism where these tokens would attend to each other ("self-attention") to predict the next layer of the hierarchy. Apparently in its initial implementation the idea worked, but not better than other contemporary approaches (incl. convolution). Then another team member, Noam Shazeer, took the idea and developed it, coming up with an architecture (which I've never seen described) that worked much better, which was then experimentally ablated to remove unnecessary components, resulting in the original transformer. I'm not sure who came up with the specific key-based form of attention in this final architecture.
4) The original transformer, as described in the "Attention Is All You Need" paper, still had a separate encoder and decoder, copying earlier RNN-based approaches, and this was used in some early models such as Google's T5 (BERT kept only the encoder). This split is unnecessary for language models, and OpenAI's GPT just used the decoder component, which is what most current LLMs use. With this decoder-only transformer architecture the input sentence is fed into the bottom layer of the transformer, and transformed one step at a time as it passes through each subsequent layer, before emerging at the top. The output at the last position of the input sequence is what gets transformed into the next token of the output sequence.
One great way to improve on this bottleneck is a new-ish idea called Ring Attention. This is a good article explaining it:
https://learnandburn.ai/p/how-to-build-a-10m-token-context
(I edited that article.)
In Unsloth, memory usage scales linearly (not quadratically) due to Flash Attention (+ you get 2x faster finetuning, 80% less VRAM use + 2x faster inference). Still O(N^2) FLOPs though.
On that note, on long contexts, Unsloth's latest release fits 4x longer contexts than HF+FA2 with +1.9% overhead. So 228K context on H100.
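A back-of-the-envelope sketch of the "linear memory, still O(N^2) FLOPs" point. It compares the size of a fully materialized N x N score matrix with the size of Q, K and V themselves, which is roughly what tiled kernels in the Flash Attention style need to keep around. The head dimension of 128, 2-byte elements, and the single-head, single-layer framing are illustrative assumptions, not Unsloth measurements:

```python
def naive_score_bytes(n_tokens, bytes_per_elem=2):
    """Memory for a fully materialized N x N attention score matrix (one head)."""
    return n_tokens * n_tokens * bytes_per_elem

def qkv_bytes(n_tokens, head_dim=128, bytes_per_elem=2):
    """Memory for Q, K and V themselves (one head): linear in N."""
    return 3 * n_tokens * head_dim * bytes_per_elem

def attention_flops(n_tokens, head_dim=128):
    """Approximate FLOPs for Q K^T plus weights @ V (one head): quadratic in N."""
    return 4 * n_tokens * n_tokens * head_dim

for n in (2_048, 32_768, 228_000):
    print(f"N={n:>7}: scores {naive_score_bytes(n) / 2**30:8.2f} GiB, "
          f"QKV {qkv_bytes(n) / 2**20:8.2f} MiB, "
          f"FLOPs {attention_flops(n):.2e}")
```

Doubling N doubles the Q/K/V memory but quadruples both the naive score matrix and the FLOP count, so tiling fixes the memory term without changing the compute term.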
Worth noting that a lot of the visualization code is available on GitHub.
https://github.com/3b1b/videos/tree/master/_2024/transformer...
I've seen several tutorials, visualisations, and blogs explaining Transformers, but I didn't fully understand them until this video.
It's a clever way to encode "position in sequence" as some kind of smooth signal that can be added to each input vector. You might appreciate this detailed explanation: https://towardsdatascience.com/master-positional-encoding-pa...
Incidentally, you can encode dates (e.g. day of week) in a model as sin(2π·day/7) and cos(2π·day/7) to ensure that "day 7" is mathematically adjacent to "day 1".
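Both ideas can be sketched in a few lines, assuming NumPy. `encode_cyclic` is a hypothetical helper for the day-of-week trick, and `sinusoidal_positions` follows the sin/cos scheme from the original transformer paper with the usual base of 10000 and an even `d_model`:

```python
import numpy as np

def encode_cyclic(value, period):
    """Map a cyclic quantity (e.g. day of week, period 7) onto the unit circle."""
    angle = 2 * np.pi * value / period
    return np.array([np.sin(angle), np.cos(angle)])

def sinusoidal_positions(n_positions, d_model):
    """Sin/cos positional encodings; assumes d_model is even."""
    positions = np.arange(n_positions)[:, None]
    freqs = 1.0 / (10000 ** (np.arange(0, d_model, 2) / d_model))
    angles = positions * freqs
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions
    pe[:, 1::2] = np.cos(angles)   # odd dimensions
    return pe

# Day 7 is exactly as close to day 1 as day 1 is to day 2.
wrap = np.linalg.norm(encode_cyclic(7, 7) - encode_cyclic(1, 7))
step = np.linalg.norm(encode_cyclic(1, 7) - encode_cyclic(2, 7))
print(wrap, step)

pe = sinusoidal_positions(n_positions=16, d_model=8)
print(pe.shape)
```

The positional encoding is the same trick repeated at many frequencies, so that both nearby and distant positions get distinguishable, smoothly varying codes.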
How well do Transformers perform given learned embeddings but only randomly initialized decoder weights? Do they produce word soup of related concepts? Anything syntactically coherent?
Once a well-trained, high-dimensional representation of tokens is established, can they learn KQ/MLP weights significantly faster?
Attention has won.[a]
It deserves to be more widely understood.
---
[a] Nothing has outperformed attention so far, not even Mamba: https://arxiv.org/abs/2402.01032
I have a spatial intuition for transformers as a sort of analog to a message passing network over a "leaky graph" in an embedding space. If each token is a node, its key vector sets the position of an outlet pipe that it spews value to diffuse out into the embedding space, while the query vector sets the position of an input pipe that sucks up value other tokens have pumped out into the same space. Then we repeat over multiple attention layers, meaning we have these higher order semantic flows through the space.
Seems to make a lot of sense to me, but I don't think I've seen this analogy anywhere else. I'm curious if anybody else thinks of transformers in this way. (Or wants to explain how wrong/insane I am?)