Training large models
Keywords
- Data Parallel
    - Gradients are accumulated onto a single GPU (the one holding the source model), which performs the optimizer step
    - Only that GPU needs to maintain the optimizer state
- Distributed Data Parallel (DDP)
    - Gradients are synced during the backward pass
    - Every GPU must maintain a full copy of the optimizer state
    - Gradients are summed via an all-reduce operation (see the DDP sketch after this list)
- Sharded
    - Fully Sharded Data Parallel (FSDP)
    - ZeRO-3
    - Key insight:
        - we can decompose the all-reduce operation into separate reduce-scatter and all-gather operations, so each GPU only ever needs to hold its own shard of the parameters, gradients, and optimizer state (see the reduce-scatter/all-gather and FSDP sketches after this list)
- Model Parallel
    - (?) tends to be much less efficient than sharded training
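
A minimal DDP sketch for concreteness: every rank holds a full replica of the model and optimizer state, and gradients are all-reduced (averaged) during `loss.backward()`. The model, data, and hyperparameters are placeholders; assumes one process per GPU on a single node, launched with torchrun.

```python
# Launch with e.g.: torchrun --nproc_per_node=4 ddp_example.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for every process
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Every rank holds a full replica of the model and the optimizer state
    model = nn.Linear(1024, 1024).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for _ in range(10):
        # Each rank trains on its own slice of the data
        x = torch.randn(32, 1024, device=f"cuda:{local_rank}")
        loss = model(x).pow(2).mean()

        optimizer.zero_grad()
        loss.backward()   # DDP all-reduces (averages) gradients during backward
        optimizer.step()  # every rank applies the identical update locally

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```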
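
A sketch of the key insight above: an all-reduce is equivalent to a reduce-scatter followed by an all-gather, and after the reduce-scatter each rank holds only its shard of the summed gradient, which is all it needs to update its shard of the optimizer state. Assumes a single node launched with torchrun and the tensor-based collectives (`reduce_scatter_tensor` / `all_gather_into_tensor`) of recent PyTorch versions.

```python
import os

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    rank = int(os.environ["LOCAL_RANK"])  # single node assumed
    torch.cuda.set_device(rank)
    world_size = dist.get_world_size()

    # Pretend this is a flattened gradient; length divides evenly by world_size
    grad = torch.randn(1024 * world_size, device=f"cuda:{rank}")

    # Path 1: plain all-reduce (what DDP does) -- every rank gets the full sum
    full_sum = grad.clone()
    dist.all_reduce(full_sum, op=dist.ReduceOp.SUM)

    # Path 2: reduce-scatter + all-gather (the sharded decomposition)
    shard = torch.empty(1024, device=f"cuda:{rank}")
    dist.reduce_scatter_tensor(shard, grad, op=dist.ReduceOp.SUM)
    # ... a sharded optimizer would update only its 1/world_size of the params here ...
    gathered = torch.empty_like(full_sum)
    dist.all_gather_into_tensor(gathered, shard)

    # The two paths agree (up to floating-point reduction order)
    torch.testing.assert_close(full_sum, gathered)
    if rank == 0:
        print("all-reduce == reduce-scatter + all-gather")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```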
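
And a minimal sketch of the PyTorch FSDP API: wrapping the model shards parameters, gradients, and optimizer state across ranks; full parameters are all-gathered on the fly for forward/backward and freed afterwards. The model and training loop are placeholders, again assuming a single node and torchrun.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Sequential(
        nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
    ).cuda(local_rank)

    # Wrapping shards the parameters across ranks (ZeRO-3 style full sharding)
    model = FSDP(model)

    # Build the optimizer *after* wrapping so its state is sharded too
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for _ in range(10):
        x = torch.randn(8, 1024, device=f"cuda:{local_rank}")
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()   # gradients are reduce-scattered into per-rank shards
        optimizer.step()  # each rank updates only its shard

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```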
Resources
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- (FAIR article) Fully Sharded Data Parallel: faster AI training with fewer GPUs
- (fairseq github examples) FSDP
- (PyTorch Blog) Introducing PyTorch Fully Sharded Data Parallel (FSDP) API
- https://towardsdatascience.com/sharded-a-new-technique-to-double-the-size-of-pytorch-models-3af057466dba
- https://www.youtube.com/watch?v=qRZrVNNe3gQ
- https://huggingface.co/docs/accelerate/usage_guides/big_modeling
- Training OPT (Susan Zhang)
- Rethinking PyTorch Fully Sharded Data Parallel (FSDP) from First Principles
- FairScale Documentation
- (PyTorch Tutorial -- beginner) DDP in PyTorch
- (PyTorch Tutorial -- intermediate) PyTorch Distributed
- (PyTorch Tutorial -- intermediate) Single-Machine Model Parallel Best Practices
- (PyTorch Tutorial) PyTorch Distributed Overview
- torchrun
- An Introduction to Distributed Deep Learning (Blog; 2016)
- The Technology Behind BLOOM Training
Code
- Fairscale
- Fairseq
- PyTorch Distributed
- torch.distributed.checkpoint
    - Saving a distributed model in SPMD style (see the sketch after this list)
    - What the heck is SPMD?? (single program, multiple data: every rank runs the same code on its own shard of the data/model)
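
A sketch of the "SPMD-style" save: instead of rank 0 gathering everything and calling `torch.save`, every rank calls the same save function and writes only the shards it owns, with the library coordinating the ranks. This assumes the FSDP-wrapped `model` and `optimizer` from the sketch above and the PyTorch 2.x `torch.distributed.checkpoint` APIs (`get_state_dict`, `dcp.save`); the checkpoint path is a placeholder.

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict

# Sharded (per-rank) views of the model and optimizer state
model_state, optim_state = get_state_dict(model, optimizer)

# Every rank makes the same call -- single program, multiple data --
# and each rank writes its own shards to the checkpoint directory
dcp.save(
    {"model": model_state, "optim": optim_state},
    checkpoint_id="checkpoint-step-1000/",
)
```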