
In the transformer (or a GPT/decoder-only model), at the end of the decoder blocks but before the final linear layer you have X vectors (one for each of the X tokens at the input of the decoder). We then want to compute the probabilities for the next token of the sequence - what do we feed to the linear layer? Is it the last embedding, i.e. the hidden state corresponding to the last token in the input sequence?

I've seen some tutorials on YouTube on how to make mini GPTs, but I never quite understood why they feed all X vectors/hidden states at the end of the decoder blocks to the linear layer, and not just the last vector/hidden state. Wouldn't you get X probability distributions when in reality you only want one? And if we do want the X probability distributions, wouldn't we be completely missing the point of the masked self-attention, since we would be trying to predict words that are already in the input sequence - essentially "cheating"?

2 Answers


Welcome to AI stack exchange!

I understand the confusion. Inference (next-token prediction) seems really counterintuitive and inefficient for transformers. And it is! The transformer is very efficient during training because all positions can be processed in parallel. At inference it is much less efficient, because generation is sequential: each new token can only be predicted once the previous one has been produced.

For transformer inference, you feed the context (your prompt) to the transformer model. It predicts the next word for each of the words in the prompt, but you only need the prediction for the last one.

A bit of pseudocode might help in understanding how a transformer can be used to generate new tokens:

# Start with some context of tokens (e.g. the tokenized prompt)
context = ...

# Generate new tokens autoregressively
for i in range(N_TOKENS_TO_GENERATE):
    logits = transformer(context)                 # One row of logits per token in the context
    probs = softmax(logits[-1])                   # Keep only the last row: the distribution for the next token
    next_token = multinomial(probs)               # Sample the next token from that distribution
    context = concatenate((context, next_token))  # Append it to the context and repeat

Now, this is the intuitive way of doing it. There are optimizations that make inference more efficient in practice (for example, caching the keys and values already computed for earlier tokens so they are not recomputed at every step). However, you cannot get around the fact that generation is sequential: every new token requires another pass through the model with the updated context. This is also why an application such as ChatGPT produces its output token by token.
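To make the shapes concrete, here is a minimal runnable sketch in PyTorch. The sizes and the `lm_head` layer name are made up for illustration and are not taken from any particular implementation:

import torch
import torch.nn as nn

vocab_size, d_model, context_len = 1000, 64, 10       # toy sizes, chosen arbitrarily

# Stand-in for the output of the decoder blocks: one d_model-sized vector per input token
decoder_output = torch.randn(context_len, d_model)

# The final linear layer maps each of those vectors to a score for every vocabulary token
lm_head = nn.Linear(d_model, vocab_size)
logits = lm_head(decoder_output)                      # shape: [context_len, vocab_size]

# At inference time only the last row matters: it scores the token that follows the whole context
next_token_probs = torch.softmax(logits[-1], dim=-1)  # shape: [vocab_size]
next_token = torch.multinomial(next_token_probs, num_samples=1)

So you do get one distribution per position, but for generating the next token you simply keep the last one.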

A small note on the side: you talk about 'hidden states' in the transformer, as if there were a recurrence going on (as in GRUs/LSTMs/RNNs). However, transformers have no such recurrence or hidden state; they operate solely using the concept of attention (hence the paper's title 'Attention Is All You Need', alluding to the fact that they don't use recurrence).

Hope this helps :)

Robin van Hoorn
  • Thank you! What about in training? What is the point of feeding all the context embeddings to the linear layer, and not only the last one, like it's done in https://github.com/karpathy/nanoGPT? Also, some of the papers I read / people I talk to sometimes talk about hidden states in transformers (maybe out of habit because of RNNs), so that is why I also mentioned hidden states haha – Miguel Carvalho Apr 21 '23 at 21:37
  • I'm not sure I understand your first follow-up question. Could you elaborate a little bit on what you mean? Regarding the hidden state, it's very understandable, but there is a notable difference, because in transformers you don't have a 'state' after a prediction which you can reuse. You have to recompute everything for the new context. – Robin van Hoorn Apr 22 '23 at 08:10
  • So what I meant was this: why are all the context embeddings at the end of the decoder blocks fed into the linear layer, and not only the last one? Wouldn't it be better to use a single vector at the linear layer instead of the whole set of context vectors? In other words, in the original [GPT paper](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf), why is it that the whole matrix `h_n` is fed to the final linear layer and not only the last token embedding (`h_n[-1]`)? – Miguel Carvalho Apr 23 '23 at 14:26
  • Because everything together represents the embedding, not just the last bit. It's a feedforward network, not a recurrent network. – Robin van Hoorn Apr 23 '23 at 17:23
  • But how does that even work? Btw I'm referring to equation 2 in that paper. If you multiply a matrix by a matrix you get a matrix... How do you get a probability distribution over all tokens for the next token in the sequence? – Miguel Carvalho Apr 23 '23 at 17:31
  • Think of it as a simple MLP for a multi-class classification problem. You simply enter the embedding (which is always equal in size), and out comes a probability distribution for which word comes next. – Robin van Hoorn Apr 23 '23 at 17:38
  • Exactly - but if you feed the entire context embeddings just after the decoder blocks (which is a matrix) to the linear layer, then you will be predicting the next token for every single token in the context, which I think is not what you want - I thought you only wanted to predict the next token AFTER the whole context/after the last token in the context... – Miguel Carvalho Apr 23 '23 at 17:46
  • No. You feed the entire context through the blocks of the transformer, then to the linear layer, and you will be predicting the next token for **the complete context**. In a transformer, one forward pass of a complete context = one prediction. – Robin van Hoorn Apr 23 '23 at 17:49
  • Let us [continue this discussion in chat](https://chat.stackexchange.com/rooms/145539/discussion-between-miguel-carvalho-and-robin-van-hoorn). – Miguel Carvalho Apr 23 '23 at 21:27

No - each next-token prediction comes from a single one of the output vectors, as you suspected: the output vector at position i is the one used to predict the token at position i+1. It has to work this way, because otherwise there would be no way to parallelize the predictions during training using a consistent set of parameters.
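Here is a rough sketch of how that plays out during training (the toy sizes and tensor names are just for illustration, not from any specific codebase): the logits at position i are scored against the actual token at position i+1, so every position contributes to the loss from a single forward pass with the same parameters.

import torch
import torch.nn.functional as F

n_context, n_vocab = 8, 1000                   # toy sizes

tokens = torch.randint(n_vocab, (n_context,))  # a training sequence of token ids
logits = torch.randn(n_context, n_vocab)       # stand-in for the transformer's output logits

# Position i predicts token i+1, so the targets are the inputs shifted left by one.
# All n_context - 1 predictions are scored at once.
loss = F.cross_entropy(logits[:-1], tokens[1:])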

My understanding is from https://transformer-circuits.pub/2021/framework/index.html, section "High Level Architecture":

What you are referring to as "X vectors" is what they call T(t) in that figure: a [n_context, n_vocab]-shaped tensor of logits. The input to the decoder, t, is the [n_context, n_vocab]-shaped tensor of one-hot encoded tokens (see the "Notation" section at the end of that article).

Note also that the original "Attention is All You Need" paper re-uses the transpose of the embedding matrix to de-embed each of the X vectors coming out of the final decoder layer (see Section 3.4 of that paper).
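A hedged sketch of that de-embedding step (the names W_E and h are purely illustrative, and the decoder blocks are elided): the same matrix that embeds token ids into vectors is re-used, transposed, to map the final vectors back to [n_context, n_vocab] logits, one distribution per position.

import torch

n_vocab, d_model, n_context = 1000, 64, 10        # toy sizes

W_E = torch.randn(n_vocab, d_model)               # shared token embedding matrix
token_ids = torch.randint(n_vocab, (n_context,))  # the context, as token ids

x = W_E[token_ids]                                # embed: [n_context, d_model]
h = x                                             # stand-in for what the decoder blocks would produce
logits = h @ W_E.T                                # de-embed with the transposed embedding matrix: [n_context, n_vocab]
probs = torch.softmax(logits, dim=-1)             # one distribution per position; the last row predicts the next token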