Why do large language models (LLMs) need massive distributed training across nodes -- if the models fit in one GPU and larger batch only decreases the variance of gradients?
tl;dr: assuming the model does not need to be sharded across nodes, why do we need (massive) distributed training if the model (e.g. CLIP, Chinchilla, even fairly large GPTs; CLIP fits in a 32GB V100) fits in one GPU and a larger batch only decreases the variance of the gradients (but does not expose more tokens or add parameter updates)? A larger batch doesn't necessarily mean we train on "more data/tokens" -- or at least that doesn't seem to be the case for SGD-like optimizers.
Intuitively, it feels like a larger batch size gives the model more tokens to learn from -- but from what I know of optimization theory and of what SGD-like algorithms actually do, a larger batch size only decreases the variance of the gradients. So it's not clear to me why massive distributed training is needed at all, unless the model is so large that it has to be sharded across nodes. In addition, even if the batch were "huge", we would still only do a single gradient update with it.
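For reference, the variance argument I have in mind is the standard one (a sketch assuming i.i.d. examples with per-example gradients $g_i$ whose covariance is $\Sigma$):

$$
\hat g_B = \frac{1}{B}\sum_{i=1}^{B} g_i,
\qquad
\mathbb{E}[\hat g_B] = \nabla_\theta L(\theta),
\qquad
\operatorname{Cov}[\hat g_B] = \frac{\Sigma}{B},
$$

so multiplying the batch size by $k$ only shrinks the gradient covariance by a factor of $k$ (the standard deviation by $\sqrt{k}$), while still producing exactly one parameter update per step.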
I feel I must be missing something obvious, hence the question, given how pervasive massive distributed training is.
In addition, some toy training curves with V100s and T5s suggest there is very little, if any, benefit from additional GPUs.
In addition, from nanoGPT it seems we know that small batch sizes are sufficient to train (reference: https://github.com/karpathy/nanoGPT, but I did ask Karpathy directly to confirm: https://github.com/karpathy/nanoGPT/issues/58).
I assume I'm missing something obvious, but I wanted to clear this up in my head, since it seems to be a foundational thing in training foundation models.
Related to the previous, I've also been unsure about the role of the batch size in training LLMs compared to traditional deep learning. In traditional deep learning, when we trained in epochs, the larger the batch size the quicker we could get through an epoch -- so the advice I received (e.g. approximate advice from Ruslan Salakhutdinov at the Simons Institute deep learning tutorials) was to make the batch size large. Intuitively, the larger the batch size, the more data the model sees per iteration. But mathematically this only improves the variance of the gradient -- and it isn't immediately obvious that this is what we want (I've done experiments and seen papers where noisy gradients lead to better models).

It is also clear that the larger the context size the better (for everything, but for the sake of this conversation, better for training) -- whenever possible. But context size is totally different from batch size.

So my question is: how does distributed training, especially at the node level, help at all if batch size isn't really the helping factor (which might be a wrong assumption)? The only role I see for distributed training is when the model is too large to fit on one node -- since I'm arguing there is no point in making the batch size too large (I'd guess 32-64 is fine due to the CLT).
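To make sure I'm picturing the mechanism correctly, here is a minimal toy sketch of what I understand one synchronous data-parallel step to do (plain PyTorch on CPU, no real multi-GPU setup; the model, data, and worker count are made-up stand-ins): each worker gets a shard of the global batch, the per-shard gradients are averaged (the all-reduce), and a single update is taken -- arithmetically the same single update you would get from processing the whole batch on one device.

```python
# Toy sketch of synchronous data parallelism on one CPU (no real GPUs/processes).
# The "workers" are just shards of one batch; real DDP would all-reduce the
# gradients across devices instead of the explicit averaging loop below.
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1)               # stand-in for a real model
loss_fn = nn.MSELoss()

global_batch_x = torch.randn(64, 16)   # toy global batch
global_batch_y = torch.randn(64, 1)

# --- one "distributed" step: 4 workers, each sees 16 examples ---
num_workers = 4
shards_x = global_batch_x.chunk(num_workers)
shards_y = global_batch_y.chunk(num_workers)

avg_grads = [torch.zeros_like(p) for p in model.parameters()]
for xs, ys in zip(shards_x, shards_y):
    model.zero_grad()
    loss_fn(model(xs), ys).backward()
    for g, p in zip(avg_grads, model.parameters()):
        g += p.grad / num_workers      # the all-reduce-mean step

# --- the same step computed as one big batch on one device ---
model.zero_grad()
loss_fn(model(global_batch_x), global_batch_y).backward()
big_batch_grads = [p.grad.clone() for p in model.parameters()]

for g_dp, g_big in zip(avg_grads, big_batch_grads):
    print(torch.allclose(g_dp, g_big, atol=1e-6))  # True: identical single update
```

So, as far as I can tell, more data-parallel workers just means the per-step batch is processed faster (or can be made larger), not that the update rule changes -- which is exactly why I don't see where the benefit comes from.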
What am I missing? Empirical answers are fine! Or any answers are fine!
Related:
- cross quora: https://www.quora.com/unanswered/Why-do-large-language-models-LLMs-need-massive-distributed-training-across-nodes-if-the-models-fit-in-one-GPU-and-larger-batch-only-decreases-the-variance-of-gradients
- cross reddit: https://www.reddit.com/r/learnmachinelearning/comments/113whxu/why_do_llms_need_massive_distributed_training/