You don't need tensors to understand attention
If you’ve ever tried to understand Large Language Models (LLMs), you likely ran into a wall of tensors and matrix multiplication. It’s easy to assume that without the heavy math, any explanation is doomed to be hand-wavy or imprecise. The surprising truth, however, is that we can describe the attention mechanism with total precision—without relying on tensors or linear algebra.1
And that’s exactly what we’ll be doing in this blog post. We are going to build a linear-algebra-free implementation of attention in pure Python, capable of running Llama 2 (albeit slowly) at around one token per second on a CPU with 64 GB of memory (no GPU required).
Disclaimer: I work on ML Infrastructure at Google. All opinions are my own, not necessarily those of my employer.
Reading time: 7 minutes
Introduction
LLMs are essentially autocomplete on steroids, so a call to an llm function may look like this:
llm("Apples are a kind of") == "Fruit"
Internally, LLMs implement this autocomplete by using an attention mechanism to focus on certain parts of the input. In this example, the LLM may be focusing (among many other things) on the noun (i.e. people, places, things) of the input sentence. The somewhat simplified internals of the LLM may thus look something like this:
llm(input):
    ...
    noun = attend_to(input, is_noun)
    ...
    if noun == "Apples":
        result = "Fruit"
    ...
In this post we'll talk about how this attend_to function is implemented.
But before we dive deep into attention, we quickly need to understand the concept of tokenization. Instead of processing strings directly, LLMs split their input into a sequence of tokens. A simple tokenization could look like the one below, where the tokens (e.g. ALWAYS) are just arbitrary numbers:2
llm([APPLES, ARE, A, KIND, OF]) == FRUIT
With this out of the way, let's dive right into attention.
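(As a quick aside, such a tokenizer can be sketched as a plain lookup table. The token numbers and the tiny vocabulary below are made up for illustration; real tokenizers are learned and cover tens of thousands of tokens.)
APPLES, ARE, A, KIND, OF = 101, 102, 103, 104, 105  # tokens are just arbitrary numbers

def tokenize(text):
    # look each word up in a fixed vocabulary
    vocab = {"apples": APPLES, "are": ARE, "a": A, "kind": KIND, "of": OF}
    return [vocab[word] for word in text.lower().split()]

tokenize("Apples are a kind of") == [APPLES, ARE, A, KIND, OF]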
Attention Scores
Attention is basically just a search function that takes a list of inputs and a score function, and returns the inputs whose score is non-zero.
For example, if you have the input sequence [APPLES, ARE, A, KIND, OF] and you have a score function that returns 1 for nouns (i.e. things, places, or people) and 0 for everything else, then you’d get:
attend_to([APPLES, ARE, A, KIND, OF], is_noun) = APPLES
I'm sure you can think of a million easy ways to implement this. The way we'll be implementing this is by multiplying all the tokens by their score and summing them up:
= is_noun(APPLES) * APPLES +
  is_noun(ARE) * ARE +
  is_noun(A) * A +
  is_noun(KIND) * KIND +
  is_noun(OF) * OF
= 1 * APPLES +
  0 * ARE +
  0 * A +
  0 * KIND +
  0 * OF
= APPLES
And here's the full code:
def attend_to(inputs, score):
    result = 0
    for input in inputs:
        result += score(input) * input
    return result
And believe it or not, this is pretty close to what an LLM actually does! Internally, LLMs will call the attend_to function many times to figure out what token to generate next; e.g. the LLM will attend to nouns to figure out what the sentence is all about, and then generate a reasonable completion, like FRUIT.
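To see this in action, here's a tiny usage sketch. The token numbers and the hand-written is_noun are invented purely for illustration; attend_to is the function we just defined:
APPLES, ARE, A, KIND, OF = 101, 102, 103, 104, 105  # arbitrary token numbers

def is_noun(token):
    # hand-written stand-in for a learned score function: 1 for nouns, 0 otherwise
    return 1 if token == APPLES else 0

# uses the attend_to defined above
attend_to([APPLES, ARE, A, KIND, OF], is_noun) == APPLES  # i.e. == 101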
While we provided a concrete implementation of the score function here (i.e. is_noun), LLMs usually learn their score functions during the training phase. We won’t go into detail in this post, but these learned functions are very simple. They don’t have loops, don’t call APIs, don’t have side-effects, and don’t have much control flow. They mostly just do a bunch of number crunching.
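To make "just number crunching" concrete, a learned function has roughly this shape: straight-line arithmetic over learned constants, with no loops or branches. The weights below are made up for illustration and are not what any real model learns:
def learned_score(input):
    # two "learned" constants; real models learn millions of these during training
    w = 0.8
    b = -0.2
    return w * input + b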
Attention Values
The above approach works well if the score function returns 1 for exactly one input, and 0 for everything else. But when there are multiple inputs with a non-zero score, the output may become nonsensical (in the example below, we’re literally adding APPLES and ORANGES).
To fix this, we pass one additional argument to the attention function. The value argument is a function that maps each input to a value that actually makes sense to add up.
For example, we could pass sweetness as the value function — which returns 1 for very sweet tokens, and 0 for tokens that are not sweet at all. The attention function would then again pay attention to the input tokens that are nouns, but instead of adding up the nouns themselves, it would add up the nouns’ sweetness.
So to figure out the sweetness of a fruit salad, we'd write:
attend_to([APPLES, ORANGES, AND, FIGS, ARE], is_noun, sweetness)
= is_noun(APPLES) * sweetness(APPLES) +
  is_noun(ORANGES) * sweetness(ORANGES) +
  is_noun(AND) * sweetness(AND) +
  is_noun(FIGS) * sweetness(FIGS) +
  is_noun(ARE) * sweetness(ARE)
= 1 * 0.6 +
  1 * 0.4 +
  0 * 0.0 +
  1 * 0.9 +
  0 * 0.0
= 1.9
With real LLMs, just like the score function, the value function is usually quite simple, and learned during the training phase.
And that's pretty close to what we want! Just two more small additions: 1) We normalize the output by the total score, so our sweetness level doesn’t grow higher and higher the more fruits we add to our salad. Hmm, yummy 0.633 (= 1.9 / 3) sweetness fruit salad. 2) For context, we also pass the last element of the inputs sequence to the score function.
With those in mind, here is our final attention implementation:
def attend_to(inputs, score, value):
    result = 0
    total_score = 0
    last_input = inputs[-1]
    for input in inputs:
        result += score(input, last_input) * value(input)
        total_score += score(input, last_input)
    return result / total_score
And that's it! You can literally use this attention function verbatim in an LLM implementation and it works. Here is an example Llama2 implementation which, when prompted with “Always answer with Haiku. I am going to Paris, what should I see?”, prints:
Eiffel Tower high
Love locks on bridge embrace
River Seine's gentle flow
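(Before we move on, here's the fruit-salad example from above written out against this final attend_to. The token numbers, is_noun, and sweetness are all made up for illustration.)
APPLES, ORANGES, AND, FIGS, ARE = 101, 106, 107, 108, 102  # arbitrary token numbers

def is_noun(token, last_token):
    # hand-written stand-in for a learned score function; it ignores last_token here
    return 1 if token in (APPLES, ORANGES, FIGS) else 0

def sweetness(token):
    # hand-written stand-in for a learned value function
    return {APPLES: 0.6, ORANGES: 0.4, FIGS: 0.9}.get(token, 0.0)

# uses the final attend_to defined above
attend_to([APPLES, ORANGES, AND, FIGS, ARE], is_noun, sweetness)  # == 1.9 / 3 ≈ 0.633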
But how can it be so simple, you may ask. If you've previously read about attention, you've probably heard terms like embeddings, keys and queries, softmax, and KV caches. What about all of that? Good questions. Here we go:
Embeddings
We've been passing tokens directly into the attention function; real LLMs use embeddings instead. An embedding replaces a token with the properties of that token. For example, the embedding structure for a token could be:
from dataclasses import dataclass

@dataclass
class Embedding:
    is_fruit: float
    is_animal: float
    is_noun: float
    is_plural: float
    sweetness: float
The embedding of APPLES would then be:
embed(APPLES) = Embedding(
    is_fruit = 1.0,
    is_animal = 0.0,
    is_noun = 1.0,
    is_plural = 1.0,
    sweetness = 0.6,
)
Here we've provided a concrete embedding of APPLES. Real LLMs usually learn the embedding for every single token. With embeddings, our score and value functions become simpler — which also means they become much easier for the LLM to learn:
def is_noun(embedding):
    return embedding.is_noun

def sweetness(embedding):
    return embedding.sweetness
Llama2's 7B model embeds tokens into vectors of 4096 floats (instead of the 5 floats in our example), and embeds values (returned by the value function) into vectors of 128 floats (instead of the single sweetness float used in our example)3.
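Putting this together, here's a hypothetical sketch of how embeddings flow through attend_to. The lookup table and its numbers are invented for illustration (real LLMs learn the embedding of every token), the token numbers are reused from the earlier sketches, and is_noun now also takes the last_input argument that the final attend_to passes to score functions:
embeddings = {
    APPLES: Embedding(is_fruit=1.0, is_animal=0.0, is_noun=1.0, is_plural=1.0, sweetness=0.6),
    ARE:    Embedding(is_fruit=0.0, is_animal=0.0, is_noun=0.0, is_plural=0.0, sweetness=0.0),
}

def embed(token):
    # real LLMs learn this lookup table during training
    return embeddings[token]

def is_noun(embedding, last_embedding):
    # ignores last_embedding in this toy example
    return embedding.is_noun

def sweetness(embedding):
    return embedding.sweetness

# attention now runs over embeddings instead of raw tokens
attend_to([embed(token) for token in [APPLES, ARE]], is_noun, sweetness)  # == 0.6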
Keys and Queries
Unlike the value function, which is completely learned, an LLM's score function has some fixed structure, and only some aspects of it are learned. This fixed structure is the following:
def score(input, last_input):
    return combine(key(input), query(last_input))
First, the score function computes some properties of the input with the learned key function; then computes some properties of the last_input in the input sequence with the learned query function; and lastly combines the two using a fixed combine function.
Both the key and query functions are usually quite simple and learned during the training phase. In Llama2 7B, both return vectors of 128 floats.
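As a toy illustration of this structure (the key and query functions below are hand-written, whereas real LLMs learn them; combine is the fixed function described in the next section):
def key(embedding):
    # "what this input offers": here, a one-element list of properties
    return [embedding.is_noun]

def query(last_embedding):
    # "what the last input is looking for": always nouns in this toy example
    return [1.0]

def score(input, last_input):
    return combine(key(input), query(last_input))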
Softmax
The combine function is fixed (not learned) and is the only part of the attention mechanism that I can't explain without using linear algebra; sorry. The combine function takes the key and query vectors, computes their dot product to turn them into a single floating-point number, and then divides by a constant (in Llama2 7B, that constant is the square root of 128). I'm not quite sure why this dot product and division is the right thing to do, and maybe there are other strategies that could make sense, but this is what it is.
Finally, the combine function raises e to the power of the score computed so far. This has the effect that the difference between scores gets amplified, so that the value with the highest score gets the vast majority of the attention (even if the score difference to the other values is relatively small). This property of putting most of the weight behind the maximum is why this mechanism is referred to as softmax.
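For intuition, here's a small worked example of that amplification (the raw scores are made up):
import math

raw_scores = [3.0, 2.0, 1.0]                    # made-up scores before exponentiation
exp_scores = [math.exp(s) for s in raw_scores]  # ≈ [20.1, 7.4, 2.7]
total = sum(exp_scores)
weights = [s / total for s in exp_scores]       # ≈ [0.665, 0.245, 0.090]
# the highest score gets about two thirds of the attention,
# even though it was only slightly larger than the others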
Here’s the full combine implementation for Llama2 7B:
def combine(key, query):
    return math.exp(torch.dot(query, key) / math.sqrt(128))
KV Cache
LLMs call the attention function many times for similar input sequences, so a lot of computational resources can be saved by using a cache to memoize the key and value function calls. This cache is called the key-value cache (aka KV cache).
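A minimal sketch of that idea, memoizing the key and value computations with plain dictionaries (real KV caches are more involved, e.g. they cache per position in the sequence and per attention layer):
key_cache = {}
value_cache = {}

def cached_key(token):
    # compute key(token) only once per token, then reuse the result
    if token not in key_cache:
        key_cache[token] = key(token)
    return key_cache[token]

def cached_value(token):
    if token not in value_cache:
        value_cache[token] = value(token)
    return value_cache[token]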
Summary
For easy reference, here’s a copy of all the important code:
def attend_to(inputs, score, value):
    result = 0
    total_score = 0
    last_input = inputs[-1]
    for input in inputs:
        result += score(input, last_input) * value(input)
        total_score += score(input, last_input)
    return result / total_score
def combine(key, query):
    return math.exp(torch.dot(query, key) / math.sqrt(N))
Conclusion
I hopefully provided you with a decent explanation of the attention function inside LLMs. And yet, there are still many things to talk about. How does the attention function really get used inside an LLM, e.g. what are encoders/decoders, what are transformers? What is the linear algebra used to implement the learned key, value, and query functions? What is normalization, what is batching? How do you actually make this run fast on a GPU using matrix multiplication, how do you train it efficiently on 20K GPUs (this is what I actually do on a daily basis)? How do you generate images? I’d love to write more about these. If you’d like to read them, please subscribe or something, so I know this is useful, and I’m not just screaming into the void :-)
1. The trick to make this possible is to treat matrix multiplication of tensors as the invocation of regular functions that happen to be linear.
2. The tokenization above is a bit simplistic; e.g. real tokenizers usually split words like "fishing" into the tokens FISH and ING.
3. When values are a vector, multiplication of the value vector V with the score S is interpreted as scalar multiplication, i.e. every element of V is multiplied by S.