Attention Explained to Ordinary Programmers
Most explanations of Large Language Models (LLMs) rely heavily on linear algebra and tensor libraries. While that complexity is essential for high-performance implementations, most of the LLM architecture can be described precisely without needing any linear algebra.1
And that’s exactly what I’ll be doing in this blog post. I’ll be explaining the attention mechanism with nothing but completely standard Python code. By the end, we’ll have an easy-to-understand attention implementation capable of running Llama2, albeit slowly, at around one token per second on a CPU with 64 GB of memory (no GPU required).
Disclaimer: I work on ML Infrastructure at Google. All opinions are my own, not necessarily those of my employer.
Reading time: 7 minutes
Tokenization
But before we dive deep into attention, we quickly need to understand the concept of tokenization. Instead of processing strings directly, LLMs split their input into a sequence of tokens. A simple tokenization could look like the one below, where the tokens (e.g. ALWAYS) are just arbitrary numbers:2
tokenize("Always answer with Haiku") = [ALWAYS, ANSWER, WITH, HAIKU]With this out of the way, let’s dive right into attention.
With this out of the way, let’s dive right into attention.

Attention Scores
Attention is basically just a search function: it takes a list of inputs and a score function, and returns the inputs whose score is non-zero.
For example, if you have the input sequence [APPLES, ARE, A, KIND, OF] and you have a score function that returns 1 for nouns (i.e. things, places, or people) and 0 for everything else, then you’d get:
attend_to([APPLES, ARE, A, KIND, OF], is_noun) = APPLES

I’m sure you can think of a million easy ways to implement this. The way we’ll be implementing this is by multiplying all the tokens by their score and summing them up:
= is_noun(APPLES) * APPLES +
  is_noun(ARE) * ARE +
  is_noun(A) * A +
  is_noun(KIND) * KIND +
  is_noun(OF) * OF
= 1 * APPLES +
  0 * ARE +
  0 * A +
  0 * KIND +
  0 * OF
= APPLES

And here’s the full code:
def attend_to(inputs, score):
    result = 0
    for input in inputs:
        result += score(input) * input
    return result

And believe it or not, this is pretty close to what an LLM actually does! Internally, LLMs call the attend_to function many times to figure out which token to generate next; e.g. the LLM will attend to the nouns (among many other things) in a sentence to figure out what the sentence is all about, and then generate a reasonable completion, like FRUIT.
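To see it in action, here’s a hypothetical way to run attend_to, with made-up token numbers and a hard-coded is_noun (in a real LLM, both the tokens and the score function come from training):

# Hypothetical token numbers; any distinct integers would do.
APPLES, ARE, A, KIND, OF = 17, 4, 2, 9, 5

def is_noun(token):
    # Pretend we know that, of these tokens, only APPLES is a noun.
    return 1 if token == APPLES else 0

attend_to([APPLES, ARE, A, KIND, OF], is_noun)  # returns 17, i.e. APPLES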
While we provided a concrete implementation of the score function here (i.e. is_noun), LLMs usually learn their score functions during the training phase. We won’t go into detail in this post, but these learned functions are very simple. They don’t have loops, don’t call APIs, don’t have side-effects, and don’t have much control flow. They mostly just do a bunch of number crunching.
Attention Values
The above approach works well if the score function returns 1 for exactly one input, and 0 for everything else. But when there are multiple inputs with a non-zero score, the output may become nonsensical (in our example, we’re literally adding APPLES and ORANGES).
To fix this, we pass one additional argument to the attention function. The value argument is a function that maps each of the inputs to values that actually make sense to add up. For example, we may pass the sweetness function, which tells us how sweet a particular noun is (1 being very sweet, and 0 being not sweet at all). So if we want to know the sweetness of a fruit-salad, we’d write:
attend_to([APPLES, ORANGES, AND, FIGS, ARE], is_noun, sweetness)
= is_noun(APPLES) * sweetness(APPLES) +
  is_noun(ORANGES) * sweetness(ORANGES) +
  is_noun(AND) * sweetness(AND) +
  is_noun(FIGS) * sweetness(FIGS) +
  is_noun(ARE) * sweetness(ARE)
= 1 * 0.6 +
  1 * 0.4 +
  0 * 0.0 +
  1 * 0.9 +
  0 * 0.0
= 1.9

With real LLMs, just like the score function, the value function is usually quite simple and learned during the training phase.
And that's pretty close to what we want! Just two more small additions: 1) For context, we also pass the last element of the inputs sequence to the score function; and 2) we normalize the output by the total score, so our sweetness level doesn’t grow higher and higher the more fruits we add to our salad. Hmm, yummy 0.633 (= 1.9 / 3) sweetness fruit salad.
With those in mind, here is our final attention implementation:
def attend_to(inputs, score, value):
    result = 0
    total_score = 0
    last_input = inputs[-1]
    for input in inputs:
        result += score(input, last_input) * value(input)
        total_score += score(input, last_input)
    return result / total_score

And that’s it! You can literally use this attention function verbatim in your LLM implementation and it works. Here is an example Llama2 implementation, which, when prompted with “Always answer with Haiku. I am going to Paris, what should I see?”, prints:
Eiffel Tower high
Love locks on bridge embrace
River Seine's gentle flow
But how can it be so simple, you may ask. If you’ve previously read about attention, you’ve probably heard terms like embeddings, keys and queries, softmax, and KV caches. What about all of that? Good questions. Here we go:
Embeddings
While we’ve been passing tokens directly into the attention function, real LLMs use embeddings instead. An embedding replaces a token with the properties of that token. For example, the embedding structure for a token could be:
from dataclasses import dataclass

@dataclass
class Embedding:
    is_fruit: float
    is_animal: float
    is_noun: float
    is_plural: float
    sweetness: float

With the embedding of the token APPLES being:
embed(APPLES) = Embedding(
    is_fruit = 1.0,
    is_animal = 0.0,
    is_noun = 1.0,
    is_plural = 1.0,
    sweetness = 0.6,
)

While we provide a concrete embedding implementation here, real LLMs usually learn the embedding. With embeddings, our score and value functions become simpler, which also means they become much easier for the LLM to learn:
def is_noun(embedding):
    return embedding.is_noun

def sweetness(embedding):
    return embedding.sweetness

Llama2’s 7B model embeds tokens into vectors of 4096 floats (instead of the 5 floats in our example), and embeds values (returned by the value function) into vectors of 128 floats (instead of the single sweetness float used in our example).3
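To tie this together, here’s a rough end-to-end sketch of the fruit-salad example, reusing the Embedding class and the final attend_to from above. The embedding values for everything except APPLES are made up, and is_noun takes (and ignores) the extra last_embedding argument so that it matches the score signature that attend_to expects:

# Toy embeddings (is_fruit, is_animal, is_noun, is_plural, sweetness);
# only APPLES' values come from the example above, the rest are made up.
APPLES  = Embedding(1.0, 0.0, 1.0, 1.0, 0.6)
ORANGES = Embedding(1.0, 0.0, 1.0, 1.0, 0.4)
AND     = Embedding(0.0, 0.0, 0.0, 0.0, 0.0)
FIGS    = Embedding(1.0, 0.0, 1.0, 1.0, 0.9)
ARE     = Embedding(0.0, 0.0, 0.0, 0.0, 0.0)

def is_noun(embedding, last_embedding):
    # The final attend_to passes the last input for context; this toy score ignores it.
    return embedding.is_noun

def sweetness(embedding):
    return embedding.sweetness

attend_to([APPLES, ORANGES, AND, FIGS, ARE], is_noun, sweetness)
# = (0.6 + 0.4 + 0.9) / 3 = 0.633..., the sweetness of our fruit salad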
Keys and Queries
An LLM’s score functions usually have some fixed structure to them, where only some aspects can be learned. This fixed structure is the following:
def score(input, last_input):
    return combine(key(input), query(last_input))

First, the score function computes some properties of the input with the learned key function; then it computes some properties of the last_input in the input sequence with the learned query function; and lastly it combines the two using a fixed combine function.
Both the key and query functions are usually quite simple and learned during the training phase. In Llama2 7B, both return vectors of 128 floats.
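Here’s a made-up toy instance of this structure, just to show its shape. The real key and query functions are learned and return vectors of 128 floats, and the real combine is described in the next section; the placeholder combine below just multiplies its two numbers:

def key(embedding):
    # A property of this input, e.g. how noun-like it is.
    return embedding.is_noun

def query(last_embedding):
    # What the last input is "looking for"; this toy version always looks for nouns.
    return 1.0

def combine(k, q):
    # Placeholder; the real combine is explained in the next section.
    return k * q

def score(input, last_input):
    return combine(key(input), query(last_input))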
Softmax
The combine function is fixed (not learned) and is the only part of the attention mechanism that I don’t know how to explain without linear algebra; sorry. The combine function takes the key and query vectors and computes their dot product to turn them into a single floating point number, followed by a division by a constant (the square root of 128 for Llama2 7B). I’m not quite sure why this dot product and division is the right thing to do, and maybe there are other strategies that could make sense, but this is what it is.
Finally, the combine function raises e to the power of the score computed so far. This has the effect that the difference between scores gets amplified, so that the value with the highest score gets the vast majority of the attention (even if the score difference to the other values is relatively small). This property of putting most of the weight behind the maximum is why this mechanism is referred to as softmax.
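To get a feel for how strong this amplification is (the numbers below are just an illustration, not taken from any real model): raw scores of 4.0 and 3.0 are fairly close, but after exponentiation and normalization by the total score, the larger one ends up with roughly 73% of the attention:

import math

a, b = math.exp(4.0), math.exp(3.0)  # ~54.6 and ~20.1
print(a / (a + b))  # ~0.73
print(b / (a + b))  # ~0.27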
Here’s the full combine implementation for Llama2 7B:
import math
import torch

def combine(key, query):
    return math.exp(torch.dot(query, key) / math.sqrt(128))

KV Cache
LLMs call the attention mechanism many times for similar input sequences, so a lot of computational resources can be saved by using a cache to memoize the key and value function calls. This cache is called the key-value cache (aka the KV cache).
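Here’s a minimal sketch of what such a cache could look like, assuming key and value functions like the ones above and using the token’s position in the sequence as the cache key; a real KV cache stores vectors and is wired directly into the attention computation:

# Hypothetical KV cache: memoize the key and value computations per token
# position, so repeated attention calls over a growing input sequence only
# do the work for tokens they haven't seen before.
key_cache = {}
value_cache = {}

def cached_key(position, embedding):
    if position not in key_cache:
        key_cache[position] = key(embedding)
    return key_cache[position]

def cached_value(position, embedding):
    if position not in value_cache:
        value_cache[position] = value(embedding)
    return value_cache[position]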
Summary
For easy reference, here’s a copy of all the important code:
def attend_to(inputs, score, value):
    result = 0
    total_score = 0
    last_input = inputs[-1]
    for input in inputs:
        result += score(input, last_input) * value(input)
        total_score += score(input, last_input)
    return result / total_score

def combine(key, query):
    return math.exp(torch.dot(query, key) / math.sqrt(N))

Conclusion
Hopefully this gave you a good understanding of the attention mechanism in LLMs. And yet, there are still many things to talk about. How does attention get used inside an LLM, e.g. what are encoders/decoders, what are transformers? What is the linear algebra used to implement the learned key, value, and query functions? What is normalization, what is batching? How do you actually make this run fast on a GPU using matrix multiplication, and how do you train it efficiently on 20K GPUs (this is what I actually do on a daily basis)? How do you generate images? I’d love to write more about these. If you’d like to read them, please subscribe or something, so I know this is useful, and I’m not just screaming into the void :-)
1. The trick that makes this possible is to treat the matrix multiplication of tensors as the invocation of linear functions.
2. The tokenization above is a bit simplistic; e.g. real tokenizers usually split words like “fishing” into the tokens FISH and ING.
3. When values are vectors, multiplying the value vector V by the score S is interpreted as scalar multiplication, i.e. every element of V is multiplied by S.


