12

Why do LLMs learn so fast during inference but, ironically, so slowly during training? That is, if you teach an AI a new concept in a prompt, it will learn and use the concept perfectly and flawlessly throughout the whole prompt, after just one shot. Yet if you train it on just a single sample, it will not influence its behavior at all; it will essentially forget. Why can't RNNs use whatever is happening during inference, rather than gradient descent, to update their weights and thus learn? In other words, can't the attention mechanism itself be used to update weights, rather than some cost function?

MaiaVictor
  • One easy answer is that, during training, a model is learning grammar, conventions, spelling, word choice, etc. but *not* concepts. During inference, a model still doesn't know about concepts, but it can tell that a cluster of tokens is novel, and various attention techniques keep that cluster's probability high. – Corbin Apr 02 '23 at 12:56

4 Answers

15

There is a huge difference between what happens to the information during training and during inference, and one cannot be used for the other.

Let me start with an analogy to the human brain (which is not a very good analogy and has several flaws, but it gives an intuition that I will build on later):
If I tell you "A palindrome is a number, a word or a sentence that reads the same forward and backward. One example palindrome is 15651", then you now know what a palindrome is and will be able to work with this new concept. If I tell a newborn baby the same thing, it will not. It takes years to bring a baby to the point where it can understand my previous sentence.

Enough of the analogy. Let's have a look at the RNNs:
RNNs are neural networks with weights. Unlike some other networks, they have something that you can call an internal state. Weights and internal state are different:

  • The internal state serves as a memory that stores information about previously processed input, e.g. a new concept that was explained earlier.
  • The weights define how new information changes the internal state and how input and internal state produce some output.
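To make the weights/state distinction concrete, here is a minimal sketch of a single-layer Elman-style RNN cell in plain Python. The class name `TinyRNN`, its sizes and its random weights are hypothetical illustrations, not taken from any real model; the point is only that `step` computes a new state from the input and the old state while never writing to the weights.

```python
import math
import random


class TinyRNN:
    """Hypothetical tiny RNN cell: h_t = tanh(W_x x_t + W_h h_{t-1} + b)."""

    def __init__(self, input_size, hidden_size, seed=0):
        rng = random.Random(seed)
        # Weights: fixed after "training" (here simply random, i.e. untrained).
        self.W_x = [[rng.uniform(-0.5, 0.5) for _ in range(input_size)] for _ in range(hidden_size)]
        self.W_h = [[rng.uniform(-0.5, 0.5) for _ in range(hidden_size)] for _ in range(hidden_size)]
        self.b = [0.0] * hidden_size
        self.hidden_size = hidden_size

    def initial_state(self):
        return [0.0] * self.hidden_size

    def step(self, x, h):
        """Return the new internal state; no weight is modified."""
        new_h = []
        for i in range(self.hidden_size):
            total = self.b[i]
            total += sum(w * xi for w, xi in zip(self.W_x[i], x))
            total += sum(w * hj for w, hj in zip(self.W_h[i], h))
            new_h.append(math.tanh(total))
        return new_h


rnn = TinyRNN(input_size=2, hidden_size=3)
W_x_before = [row[:] for row in rnn.W_x]
W_h_before = [row[:] for row in rnn.W_h]

h = rnn.initial_state()
for x in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
    h = rnn.step(x, h)  # the state now carries information about earlier inputs

print(h)                                                 # the state has changed
print(rnn.W_x == W_x_before and rnn.W_h == W_h_before)   # True: weights untouched
```

Training would be a separate procedure that changes `W_x`, `W_h` and `b`; processing input, however long, only ever changes `h`.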

So untrained neural networks typically have randomly initialized weights. New input will then cause more or less arbitrary output and updates of the internal state. So if you give a new concept to an untrained neural network as input, it will not learn the new concept; it will just update the internal state into meaningless numbers and produce random gibberish as output.

If you train the model, the weights are updated to serve a given purpose. This can be a relatively simple model, e.g. one that detects whether a tweet is positive or negative. In this case, the network would only be trained on tweets, and the internal state would only represent the positivity of the previous words and maybe whether the last word was "not", to distinguish "I am happy" from "I am not happy". In practice the state would probably be much more detailed and not so easy to interpret, but it would be something like this.

But if you build an LLM, you train it on much more heterogeneous data and for tasks that involve understanding new concepts. In this case the weights of the model are learned in a way that lets the network process new concepts and store the essence of each concept in the internal state.

In short: teaching the network new concepts (which is an update of the internal state) is only possible because a long prior training of the weights has enabled the LLM to do so.

Avoiding backpropagation: there is some recent work exploring new ways of training that avoid costly backpropagation and instead perform two forward passes: the Forward-Forward algorithm. To my knowledge, it is not used for LLMs yet. And even if it were, one would still need to train it on a huge amount of data to learn weights that allow the network to process new concepts as input.
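The Forward-Forward idea can be sketched for a single layer: the layer gets a local objective, a "goodness" score (here the sum of squared ReLU activations), which it pushes above a threshold for positive (real) data and below it for negative data, using one forward pass per kind of data and no backward pass through other layers. This is a simplified, hypothetical sketch (one layer, two toy input vectors, small positive initial weights so every unit starts active), not a faithful reproduction of the published algorithm.

```python
import math
import random


class FFLayer:
    """One ReLU layer trained with a local Forward-Forward-style objective (sketch)."""

    def __init__(self, n_in, n_out, threshold=2.0, lr=0.05, seed=0):
        rng = random.Random(seed)
        # Small positive weights so every unit is active for the toy inputs below.
        self.W = [[rng.uniform(0.05, 0.5) for _ in range(n_in)] for _ in range(n_out)]
        self.threshold = threshold
        self.lr = lr

    def forward(self, x):
        return [max(0.0, sum(w * xk for w, xk in zip(row, x))) for row in self.W]

    def goodness(self, x):
        # Goodness = sum of squared activations of this layer only.
        return sum(y * y for y in self.forward(x))

    def p_positive(self, x):
        # Probability that x is positive data: sigmoid(goodness - threshold).
        return 1.0 / (1.0 + math.exp(self.threshold - self.goodness(x)))

    def local_update(self, x, is_positive):
        """One forward pass plus a layer-local gradient step on the logistic loss."""
        y = self.forward(x)
        p = 1.0 / (1.0 + math.exp(self.threshold - sum(v * v for v in y)))
        dloss_dgood = (p - 1.0) if is_positive else p
        for j, row in enumerate(self.W):
            coef = dloss_dgood * 2.0 * y[j]  # y[j] == 0 when the ReLU is off
            for k in range(len(row)):
                row[k] -= self.lr * coef * x[k]


layer = FFLayer(n_in=4, n_out=8)
positive = [1.0, 0.0, 1.0, 0.0]  # stands in for real data
negative = [0.0, 1.0, 0.0, 1.0]  # stands in for generated "negative" data

for _ in range(200):
    layer.local_update(positive, is_positive=True)   # forward pass on positive data
    layer.local_update(negative, is_positive=False)  # forward pass on negative data

print(layer.goodness(positive), layer.goodness(negative))
```

After training, goodness is high for the positive vector and near zero for the negative one. Note that even here the weights only encode what was in the training data, which is the point made in the paragraph above.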

Laurel
Broele
  • So why do we need to train the LLM with the entire Internet corpus, rather than training it with a limited corpus until it can accept new corpus from its prompt and then giving the rest of the corpus as the prompt? – user253751 Mar 31 '23 at 14:38
  • _"training it with a limited corpus until it can accept new corpus from its prompt"_ But the *until* never happens if you train it with a limited corpus. If the training set isn't big enough it is much more likely that the training will result in weights that make it output human-like writing only in the training set but not when given new prompts. – JiK Mar 31 '23 at 21:06
  • But if you teach a human what a palindrome is only once, they will be able to use that concept for the rest of their life, assuming they paid attention. With LLMs, if they never heard about palindromes before and you try to teach them what a palindrome is by feeding a single explanation in a fine-tuning step, they will not learn it at all. And that's the problem: for humans, inference and training are unified; we can "update weights" to learn a new concept forever from a single example, just as an LLM can learn that concept temporarily in a prompt, but only for its duration. – MaiaVictor Apr 01 '23 at 14:01
  • @JiK do you think ChatGPT is only able to learn from input that appears in its training set? If I tell it "squoozles masquerade priffins. What does a squoozle masquerade?" you think it can't answer that because it's never heard of a squoozle or a priffin before? – user253751 Apr 01 '23 at 14:03
  • For an LLM, the difference is again the difference between state and weights. While the state is typically reset after each session, the weights stay the same. It would theoretically be possible to keep one state across sessions, and then the model would learn new concepts during inference and remember them. But then the network would need a much bigger state to store all the new concepts, and the risk would be that some concepts overwrite others. – Broele Apr 01 '23 at 14:06
  • @user253751 as I wrote before, there are two ways for the RNN to store information: weights and state. The state covers all the information given in the current session, but has limited size (so during a longer session, knowledge might get lost). The weights store all the knowledge and "abilities" learned during training. This training needs so much input data for two reasons: 1. to store knowledge to answer questions and 2. to learn to handle a high variety of tasks. If you only train on English Wikipedia, it will know a lot, but it can only create Wikipedia-like articles. – Broele Apr 01 '23 at 14:18
  • @Broele and yet the one trained on English wikipedia can still achieve: "When I say 'wer bist du?' you say 'ich bin ChatGPT!'" The training data doesn't seem to specifically limit what it can accomplish, in the way you would generally expect it to. Perhaps the difference is that it could never learn to generalize handling of *instructions* written in German. – user253751 Apr 01 '23 at 14:20
  • @user253751: I have to admit that I am not sure which model you mean. Typically, more than just Wikipedia is involved. And note that LLMs typically have multiple phases of training, including pre-training, unsupervised training and refinements. Nevertheless, the rule of thumb that more heterogeneous data leads to more generalized abilities still holds. – Broele Apr 01 '23 at 14:37
  • @Broele there probably isn't one only trained on English wikipedia, but if there was, it would be able to follow that instruction. We know because the one trained on the whole Internet is still able to recognize made-up words or keyboard mashing – user253751 Apr 01 '23 at 14:54
  • @user253751 There is a difference between telling the model how to reply to a German greeting (i.e. providing knowledge during inference) and the model having the ability to speak German. Of course, you could start by giving the model an extensive German course during inference before speaking German with it. But it is better to train it on German texts as well. And (I am speculating now) it is quite likely that a model trained also on texts from language courses is much better at processing German courses than a model trained only on Wikipedia (training gives both knowledge and abilities). – Broele Apr 02 '23 at 14:24
  • @user253751 _"do you think ChatGPT is only able to learn from input that appears in its training set?"_ Not in the way you use the word "learn". Let me rephrase: for a given training set, there are different possible sets of weights that make the model score well on the training set. In general, far fewer of them allow the model to perform well on new input. As you increase the training set, it becomes more and more likely that a set of weights that scores well on the training set is also one that performs well in general. – JiK Apr 03 '23 at 10:25
  • Actually @user253751, your gibberish example about priffins currently leads ChatGPT to ask for an explanation, because those are not commonly used words and the meaning is unclear. And if you did take enough time to explain and define everything, that "training" would disappear once it slides out of ChatGPT's recent context window or you start a new session. – IronSean Apr 04 '23 at 20:23
10

They are not "learning" during inference at all.

Learning is the process of updating the weights of the model (to lower loss). This does not happen during inference. The model weights stay the same.

When you are "teaching an AI a new concept", you are just giving it some context, which improves its ability to answer what happens next. But weights do not get updated. That is the computationally expensive part.
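A toy illustration of that distinction, assuming a hypothetical two-weight linear model: a forward pass (`predict`) only reads the weights, while a gradient-descent step (`sgd_step`) is what actually changes them. Real training repeats such steps over a huge number of weights and examples, which is where the cost comes from.

```python
def predict(w, x):
    """Forward pass (inference): reads the weights, never writes them."""
    return sum(wi * xi for wi, xi in zip(w, x))


def sgd_step(w, x, target, lr=0.05):
    """Training: one gradient step on the squared error (predict(w, x) - target)**2."""
    error = predict(w, x) - target
    # d/dw_i of error**2 is 2 * error * x_i
    return [wi - lr * 2 * error * xi for wi, xi in zip(w, x)]


w = [0.5, -0.25]
x = [1.0, 2.0]

y = predict(w, x)              # 0.5 * 1.0 + (-0.25) * 2.0 = 0.0
print(w)                       # [0.5, -0.25]: unchanged by inference

w_trained = sgd_step(w, x, target=1.0)
print(w_trained)               # approximately [0.6, -0.05]: only training moved the weights
print(predict(w_trained, x))   # approximately 0.5, closer to the target 1.0 than y was
```

Putting more text into the prompt corresponds to changing `x`, never `w`.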

In a meta context, I guess you can call this "learning", and this is effectively what is being done with Microsoft's new venture with GPT. They are letting it search things on the fly. Probably a lot of interesting research will surface using such techniques soon :)

shatz
  • I think both answers here miss my point though! It is absolutely capable of learning concepts very fast at inference, and, yes, it is *not* updating its weights. Do you see the problem there? (; – MaiaVictor Mar 31 '23 at 14:30
  • It is capable of "using context" to correctly solve problems that other models have to be trained to solve. Often, you can even present descriptions of completely new algorithms as "context" and it can execute them. This seems a lot like training. – user253751 Mar 31 '23 at 14:37
  • It is "learning" in the sense that you are giving it some context. But we associate the term "learning" with how humans learn something. For some intuition, if you consider a human to be a neural network, it's as if we are neural networks that have no ability to control whether or not our weights get updated, so we are always "learning". But in this context, the definition of "learning" is not the same, because if you switch to another context (like another chat in ChatGPT), it will not remember the context. – shatz Mar 31 '23 at 14:38
  • If you really want to, you can say the model has "learned" how to use context. But that is as far as I would go. – shatz Mar 31 '23 at 15:06
  • @MaiaVictor: since you think that our answers miss the point, I will try to update my answer and hopefully meet your expectations a bit better. In short (and for my taste too short): training = learning how to learn new concepts + learning a huge knowledge base; inference = learning new concepts fast, but only if the training has been done before. – Broele Mar 31 '23 at 20:43
  • @shatz It is not at all difficult to imagine equipping a GPT model with a really long context window - or wiring several attention layers in parallel together, which is how the original Transformer worked. That model fed each attention layer not just the outputs from the previous layer of the sentence it was generating, but *also* fed it the outputs from the *input* sentence, processed completely separately. This was used for a translation task but it's also not hard to imagine accessing a fact database in a similar sort of way. – user253751 Mar 31 '23 at 21:31
  • @Broele I wasn't saying that in a condescending tone, both answers were great! I just wanted to point that perhaps there is some way to adjust weights based not on gradient descent, but on the attention mechanism itself, that we're somehow missing. Thanks for the work you put in it. – MaiaVictor Mar 31 '23 at 22:26
  • @MaiaVictor: I was not understanding it in a condescending tone :). But it sounds like your question is not fully answered. If that's the case, we might try to improve the answer. – Broele Apr 01 '23 at 12:27
3

It's kind of like short-term memory versus long-term memory. Giving a language model a small amount of information at inference time allows it to use that information, and so you might say that the model has "learned" that information, but this "learning" isn't really useful in the long term.

For RNNs, the problem is that the state vector only contains a limited amount of information. You can tell an RNN something once, but as you give it more information, it will forget what you told it previously. So if you have a large amount of information that you want your RNN to be able to access, then providing that information as input during inference won't do the trick; you need to train it.
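A deliberately simplified illustration of that forgetting, assuming a hypothetical one-number linear RNN state `h = decay * h + x`: the influence of the first input on the final state shrinks geometrically with every later input. Real RNNs use larger, nonlinear (often gated) states, but any fixed-size state has the same basic capacity limit.

```python
def run_linear_rnn(inputs, decay=0.5):
    """A one-number 'state' updated by every input."""
    h = 0.0
    for x in inputs:
        h = decay * h + x
    return h


def influence_of_first_input(n_later_inputs, decay=0.5):
    """How much the final state changes when only the first input is changed."""
    later = [0.0] * n_later_inputs
    return run_linear_rnn([1.0] + later, decay) - run_linear_rnn([0.0] + later, decay)


for n in (0, 5, 20):
    print(n, influence_of_first_input(n))
# 0 1.0
# 5 0.03125
# 20 9.5367431640625e-07
```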

For transformers, the problem is that the amount of time it takes the model to process a token of input is proportional to the number of tokens it's already processed. If you have just a small amount of information that you want the transformer to learn, that's not a problem, but if you try to give a transformer a very large amount of information as input, that will make inference very slow.
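A back-of-the-envelope sketch of that cost, assuming causal self-attention in which the token at position t is compared with all t positions seen so far (one head, one layer, ignoring everything else a real transformer does): the newest token costs as many query-key dot products as the context is long, and processing the whole input costs about n²/2, so doubling the input roughly quadruples the work.

```python
def attention_dot_products(num_tokens):
    """Query-key dot products for causal self-attention over num_tokens tokens."""
    last_token_cost = num_tokens  # the newest token looks at every position
    total_cost = sum(t for t in range(1, num_tokens + 1))  # 1 + 2 + ... + num_tokens
    return last_token_cost, total_cost


for n in (1_000, 2_000, 4_000):
    print(n, attention_dot_products(n))
# 1000 (1000, 500500)
# 2000 (2000, 2001000)
# 4000 (4000, 8002000)
```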

Note that language models are sometimes permanently "taught" things by means of input instead of training. For example, it's been reported that ChatGPT and Bing Chat have a hard-coded prompt that's always present at the beginning of the input, and which contains some information about what the developers want the model to do.
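A minimal sketch of that pattern, under loudly stated assumptions: the function name `build_model_input`, the plain newline-joined format and the character budget standing in for a token limit are all hypothetical, since this answer does not describe the actual format used by ChatGPT or Bing Chat. The fixed prompt is always kept, while older conversation turns fall out once the budget is exhausted, which is also why something "taught" in an old turn stops being available.

```python
def build_model_input(system_prompt, turns, max_chars=200):
    """Prepend a fixed system prompt and keep as many recent turns as fit.

    max_chars is a hypothetical stand-in for a token budget.
    """
    budget = max_chars - len(system_prompt)
    kept = []
    for turn in reversed(turns):  # newest turns first
        cost = len(turn) + 1      # +1 for the joining newline
        if cost > budget:
            break                 # this turn and all older ones are dropped
        kept.append(turn)
        budget -= cost
    kept.reverse()                # restore chronological order
    return "\n".join([system_prompt] + kept)


SYSTEM_PROMPT = "You are a helpful assistant."  # hypothetical hard-coded prompt
turns = [
    "User: a squoozle is a kind of bird.",
    "Assistant: Noted.",
    "User: what is a squoozle?",
]
print(build_model_input(SYSTEM_PROMPT, turns, max_chars=80))
```

With `max_chars=80` the definition of "squoozle" no longer fits and is dropped, while the system prompt survives; with a larger budget all three turns are kept.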

Tanner Swett
3

As pointed out by others, what you call "learning" at inference is nothing more than providing more context. The model can indeed memorize things in its short-term memory, but only for the current task at hand. You suggest that we could make a model with an infinite contextual memory, but then it would mix up all tasks together. It would literally be like having to recite all the numbers you ever calculated, counted or saw before starting a new calculation.

Hence, contextualization is only useful for short-term tasks, and it works only thanks to the slow learning phase you have to do the first time around, which is more formally called the "convergence process".

So, what you are looking for is in fact a way to make the convergence process faster, and more precisely one-shot or zero-shot learning. If you look beyond LLMs (Large Language Models) and RNNs (Recurrent Neural Networks), there are a lot of other AI models that can do one-shot or even zero-shot learning, such as memory models like the Gripon-Berrou neural network. One-shot learning models can learn the first time they see an example and generalize from it. Zero-shot learning models can even learn without being presented any examples, by generalizing from others or by transferring knowledge from another field.
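To show what one-shot storage with binary weights can look like, here is a heavily simplified sketch in the spirit of a Gripon-Berrou (clique-based) associative memory: a message is one active neuron per cluster, storing it wires all those neurons together in a single presentation, and retrieval fills in erased parts by counting connections to the known neurons. The class name, the sizes and the single-step winner-take-all retrieval are simplifications for illustration; the published networks use iterative retrieval and further rules.

```python
from itertools import combinations


class CliqueMemory:
    """Simplified Gripon-Berrou-style associative memory with binary weights."""

    def __init__(self, num_clusters, cluster_size):
        self.num_clusters = num_clusters
        self.cluster_size = cluster_size
        self.edges = set()  # a pair is either connected (weight 1) or not (weight 0)

    def _neuron(self, cluster, symbol):
        return cluster * self.cluster_size + symbol

    def store(self, message):
        """One-shot learning: a single presentation wires a clique of neurons."""
        if len(message) != self.num_clusters:
            raise ValueError("message must have one symbol per cluster")
        neurons = [self._neuron(i, s) for i, s in enumerate(message)]
        for a, b in combinations(sorted(neurons), 2):
            self.edges.add((a, b))

    def _connected(self, a, b):
        return (min(a, b), max(a, b)) in self.edges

    def retrieve(self, partial):
        """Fill in None entries by winner-take-all on connections to known neurons."""
        known = [self._neuron(i, s) for i, s in enumerate(partial) if s is not None]
        result = list(partial)
        for i, s in enumerate(partial):
            if s is None:
                scores = [sum(self._connected(self._neuron(i, cand), k) for k in known)
                          for cand in range(self.cluster_size)]
                result[i] = max(range(self.cluster_size), key=scores.__getitem__)
        return result


memory = CliqueMemory(num_clusters=4, cluster_size=8)
memory.store((1, 5, 2, 7))  # each message is presented exactly once
memory.store((3, 0, 6, 4))
print(memory.retrieve((1, 5, None, None)))  # [1, 5, 2, 7]
print(memory.retrieve((None, 0, 6, None)))  # [3, 0, 6, 4]
```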

For example, Text2Video-Zero is a recently published text-to-video generator that did NOT learn from any video, but instead reused the weights of Stable Diffusion, which was trained on still images. What this algorithm does is cleverly generalize what was learned from still images into a coherent sequence of images with the same style, hence mimicking motion, with no additional training. Of course, it's not completely zero-shot, because it has to be provided with Stable Diffusion weights first, but essentially, zero-shot learning means that you can reuse a model that was made for one purpose for another purpose, for free (i.e., you can run inference directly, with no need to re-learn anything).

Technically, one-/zero-shot learning typically requires another kind of architecture, a more brain-like one (i.e., with discrete 0/1 synaptic weights). The long convergence processes are usually required by networks using floating-point weights (i.e., McCulloch-Pitts neurons). This is because floating-point weights are not at all biologically plausible: they are a mathematical abstraction that condenses several functions of biological neural networks into fewer abstractions that are more amenable to programming.

Likewise, convolution layers in CNNs (convolutional neural networks) are another abstraction of how biological systems integrate big populations of neurons, but here we can use a much smaller population of artificial neurons and more optimized instruction sets to do the same work as the brain does. Keep in mind that for many purposes in AI, current computers are much less efficient than the human brain, which is why all these highly synthetic reproductions, more optimized for the machine but very remote from how real biological systems work, are necessary. Here, long convergence (i.e., long learning) is an unavoidable artifact of how we model our artificial neurons and synapses: with floating-point numbers instead of discrete (binary) values, and with mathematical functions for integration instead of analog biological integration (which is both more fine-grained and simpler than numerical functions; see for example the Veritasium videos about analog computers, as biological systems have similar properties and advantages).

RNNs are a kind of opposite approach and problem: they use a more biologically plausible property, recurrence, but we have a hard time designing artificial systems that learn recurrent networks efficiently. So here it's the opposite of what can be observed with CNNs and LLMs: the long convergence is due to current science providing inefficient learning algorithms when recurrence is involved. The last few years saw tremendous progress on this, with very clever algorithms, but it's still very far from how neatly biological systems manage recurrence.

All that is to say, to answer your question directly (why can't current LLM and RNN models learn in zero/one-shot from the get-go?): it's because nobody has found a way to mathematically formulate such a model. Maybe someone will in the near future, maybe it will take decades, but for now, it's the slow-converging LLM and RNN models that work, and they are the ones that provide you with hyped tools such as ChatGPT.

Personally, I think we won't get there until we find out how analog biological neural systems work, and then we will need to develop new computer technologies to mimic them. There is already a lot of work in this direction, such as reprogramming biological neurons via RNA signalling or mixing them with silicon neurons, but it's still far from the "real deal". There are at least hundreds of different types of neurons, and there are many other neural cell types whose functions are not completely understood. We are far from fully understanding biological neural systems, but progress is continuous and steady.

Disclaimer: I am both an AI researcher and a clinical neuroscientist and I studied some computational neuroscience.


EDIT: A small update to extend my explanation above for the technically and philosophically inclined: learning at its most fundamental level can be defined as the ability of a system to modify its structure to reflect some input signal, and memory as the system itself that can modify its structure according to input signals. In biological systems, there are two types of memory: short-term and long-term. Recent artificial recurrent neural network models try to mimic this, with the very famous LSTM model (Long Short-Term Memory), itself a precursor of the GPT models. By convention, in machine learning we call the tweaking of the weights, i.e., of the long-term memory, "learning". But there is indeed also a short-term memory, which has its own weights, and AI researchers don't call this process learning, although it technically is by all standards, the only difference being the exact method used and how long the memory is retained.
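A single-unit LSTM step, sketched in plain Python with hypothetical fixed weights, shows that split: the gates update the hidden state `h` and the cell state `c` for every input at inference time, while the weight table itself is never modified. The gate equations are the standard LSTM ones; the numbers are arbitrary.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical fixed weights (the long-term memory) for a 1-unit LSTM:
# each gate has (weight on input x, weight on previous h, bias).
WEIGHTS = {
    "input": (1.0, 0.5, 0.0),
    "forget": (0.5, 0.5, 1.0),
    "output": (1.0, 0.5, 0.0),
    "candidate": (1.0, 0.5, 0.0),
}


def lstm_step(x, h, c, weights=WEIGHTS):
    """One standard LSTM step; returns the new (h, c) and leaves the weights untouched."""
    def pre(gate):
        w_x, w_h, b = weights[gate]
        return w_x * x + w_h * h + b

    i = sigmoid(pre("input"))        # how much new information to write
    f = sigmoid(pre("forget"))       # how much of the old cell state to keep
    o = sigmoid(pre("output"))       # how much of the cell state to expose
    g = math.tanh(pre("candidate"))  # candidate new information
    c_new = f * c + i * g
    h_new = o * math.tanh(c_new)
    return h_new, c_new


weights_snapshot = dict(WEIGHTS)
h, c = 0.0, 0.0                      # short-term memory starts empty
for x in [1.0, -1.0, 0.5]:
    h, c = lstm_step(x, h, c)        # updated at inference, once per input
print(h, c)
print(WEIGHTS == weights_snapshot)   # True: no weight changed
```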

And just like there are models that modify/learn short-term memory at inference but not long-term memory, there are models that tweak their long-term memory at inference, notably Bayesian models, as often used for weather forecasting.
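As a toy example of a model whose long-term parameters change at inference time, consider a Beta-Bernoulli model estimating the probability of rain (a hypothetical setup, far simpler than real forecasting systems): every new observation permanently updates the parameters `alpha` and `beta` in closed form, with no separate training phase.

```python
class RainBelief:
    """Beta(alpha, beta) belief over the probability of rain, updated online."""

    def __init__(self, alpha=1.0, beta=1.0):
        self.alpha = alpha  # pseudo-count of rainy days (1.0 = uniform prior)
        self.beta = beta    # pseudo-count of dry days

    def probability_of_rain(self):
        return self.alpha / (self.alpha + self.beta)  # posterior mean

    def observe(self, rained):
        # Conjugate Bayesian update: the parameters themselves change and persist.
        if rained:
            self.alpha += 1.0
        else:
            self.beta += 1.0


belief = RainBelief()
for rained in [True, False, True, True]:
    belief.observe(rained)
print(belief.alpha, belief.beta, belief.probability_of_rain())  # 4.0 2.0 0.6666666666666666
```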

So LLMs and RNNs learn fast during inference because they are designed to learn only short-term memory, while the large set of long-term memory weights was learned beforehand. But future improvements of the technology may well make it possible to design networks that also learn long-term memory "online", in real time, in a stochastic manner with a guarantee of convergence.
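A minimal sketch of "online" learning with a convergence guarantee, assuming the simplest possible case: estimating the mean of a data stream by stochastic gradient steps on the squared error, with step sizes 1/t. These step sizes satisfy the Robbins-Monro conditions (their sum diverges while the sum of their squares converges), the classical route to such guarantees; scaling this up to the long-term weights of a large network is what the paragraph above leaves to future work.

```python
import random


def online_mean(stream):
    """Stochastic approximation: theta_t = theta_{t-1} - (1/t) * (theta_{t-1} - x_t)."""
    theta = 0.0
    for t, x in enumerate(stream, start=1):
        step = 1.0 / t               # Robbins-Monro step size
        theta -= step * (theta - x)  # gradient of 0.5 * (theta - x)**2 is (theta - x)
    return theta


rng = random.Random(42)
stream = (rng.gauss(3.0, 1.0) for _ in range(10_000))
estimate = online_mean(stream)
print(estimate)  # close to the true mean 3.0
```

Each observation updates the single persistent parameter immediately; with step size 1/t the estimate equals the running sample mean.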

gaborous
  • Thanks, this was the answer I was looking for. Just to be sure: do we know that the McCulloch-Pitts neuron is how a human neuron works? I.e., are our brains essentially networks of M-P neurons? So the only question left is how these neurons update their weights and prune or form new synapses? – MaiaVictor Apr 09 '23 at 15:04
  • @MaiaVictor The opposite: we know brain neurons (not just humans') do NOT work like McCulloch-Pitts neurons. But those are still a good enough model for most purposes. There are more [biologically accurate/plausible neuron models](https://en.wikipedia.org/wiki/Biological_neuron_model), those that can generate action potentials, one of the oldest being Hodgkin–Huxley. But they are computationally less efficient, so they are better suited to studying human neurons in silico than to making the next big AI app that everyone will want to use to boost their productivity. See computational neuroscience. – gaborous Apr 10 '23 at 00:21
  • @MaiaVictor Also, modern neurobiology has ditched the old neuron-centric view: we are learning more and more about all the other neural structures that are as essential as neurons, if not more so, for the brain to work: synapses, glial cells, neuromodulators, neurotransmitters, etc. All of them also have lots of subtypes. You can see the brain as being just like the whole human population: there are all sorts of colors and hairstyles and face and body shapes, and people build houses of all shapes and designs; there is an unfathomable diversity distributed around the world, adapted to their respective environments. – gaborous Apr 10 '23 at 00:37