Training large models

Keywords

  • Data Parallel (torch.nn.DataParallel, single-process)
    • Outputs are gathered onto one primary GPU; the loss is computed there and gradients accumulate on that GPU's parameters
    • Only that GPU needs to maintain the optimizer state (see the DP sketch after this list)
  • Distributed Data Parallel (DDP)
    • One process per GPU; gradients are synced during the backward pass (not the forward pass)
    • Every GPU maintains a full replica of the parameters and the optimizer state
    • Gradients are summed (then averaged) across ranks via an all-reduce operation (see the DDP sketch after this list)
  • Sharded
    • Fully Sharded Data Parallel (FSDP)
    • ZeRO-3 (DeepSpeed)
    • Key insight:
      • the all-reduce can be decomposed into separate reduce-scatter and all-gather operations, so in between each rank only needs to hold its own shard of the gradients, parameters, and optimizer state (see the reduce-scatter/all-gather sketch after this list)
  • Model Parallel
    • (?) naive (non-pipelined) model parallelism keeps all but one GPU idle at any given moment, so it tends to be much less efficient than sharded training (see the sketch after this list)
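
A minimal torch.nn.DataParallel sketch (the layer sizes, device count, and learning rate are made-up placeholders). The point is that the replicas are temporary: the real parameters, gradients, and optimizer state all live on cuda:0.

```python
import torch

# Toy model; DataParallel replicates it onto every listed GPU each forward pass.
model = torch.nn.Linear(1024, 1024).to("cuda:0")
dp_model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # state lives on cuda:0 only

x = torch.randn(32, 1024, device="cuda:0")  # the batch is scattered across the replicas
loss = dp_model(x).sum()                    # outputs are gathered back onto cuda:0
loss.backward()                             # gradients accumulate on cuda:0's parameters
optimizer.step()
```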
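
A minimal DistributedDataParallel sketch, assuming a single node with one process per GPU launched by torchrun (the model, sizes, and learning rate are placeholders):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")     # torchrun sets the rank/world-size env vars
local_rank = int(os.environ["LOCAL_RANK"])  # also set by torchrun
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(1024, 1024).cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)  # full optimizer state on every rank

x = torch.randn(32, 1024, device=torch.device("cuda", local_rank))  # each rank should feed a different data slice
loss = ddp_model(x).sum()
loss.backward()                             # gradients are all-reduced across ranks during backward
optimizer.step()                            # every rank applies the same update to its replica
dist.destroy_process_group()
```

Launched with something like `torchrun --nproc_per_node=4 ddp_sketch.py` (the script name is a placeholder).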
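
A sketch of the key insight above using raw torch.distributed collectives (the tensor shape is arbitrary, and a process group is assumed to be initialized already, e.g. as in the DDP sketch): the reduce-scatter leaves each rank with only its own reduced shard, which is all it needs to update its slice of the parameters and optimizer state; the all-gather rebuilds the full tensor only when it is actually needed, and the two steps together are equivalent to a single all-reduce.

```python
import torch
import torch.distributed as dist

world = dist.get_world_size()
rank = dist.get_rank()
device = torch.device("cuda", torch.cuda.current_device())

grad = torch.randn(world * 4, device=device)   # toy "gradient", same shape on every rank
shards = list(grad.chunk(world))               # one shard per rank

# Step 1: reduce-scatter. Each rank receives only its own summed shard;
# in FSDP / ZeRO-3 this is when it updates just its slice of the optimizer state.
my_shard = torch.empty_like(shards[0])
dist.reduce_scatter(my_shard, shards)

# Step 2: all-gather. Reassemble the full reduced tensor when the whole thing is needed.
gathered = [torch.empty_like(my_shard) for _ in range(world)]
dist.all_gather(gathered, my_shard)
full = torch.cat(gathered)                     # equal to dist.all_reduce(grad) on every rank
```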
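
A naive model-parallel sketch (two made-up linear stages split across two GPUs). Each GPU holds only its own layers, but while one stage computes the other sits idle, which is the inefficiency the bullet above refers to; pipeline parallelism exists to fill those bubbles.

```python
import torch

class TwoStageModel(torch.nn.Module):
    """Layers split across two GPUs; activations are copied between devices."""
    def __init__(self):
        super().__init__()
        self.stage0 = torch.nn.Linear(1024, 4096).to("cuda:0")
        self.stage1 = torch.nn.Linear(4096, 1024).to("cuda:1")

    def forward(self, x):
        h = torch.relu(self.stage0(x.to("cuda:0")))
        return self.stage1(h.to("cuda:1"))  # cuda:0 is idle while cuda:1 computes, and vice versa

model = TwoStageModel()
loss = model(torch.randn(32, 1024)).sum()   # loss ends up on cuda:1
loss.backward()                             # autograd routes gradients back across devices
```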

Resources

Code

  • Fairscale
  • Fairseq
  • PyTorch Distributed
  • torch.distributed.checkpoint
    • Saves/loads a distributed model in SPMD style: each rank writes only the shards it owns, in parallel
    • SPMD = Single Program, Multiple Data: every rank runs the same script and operates on its own shard of the data/model (see the checkpoint sketch below)
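
A hedged sketch of saving with torch.distributed.checkpoint: the dcp.save / dcp.load calls with checkpoint_id are from recent PyTorch releases (this API has changed across versions), the FSDP wrapping and the checkpoint path are placeholder assumptions, and depending on the version you may need to configure FSDP to hand out a sharded state_dict. The SPMD part is that every rank executes exactly the same code and only touches the shards it owns.

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes the process group is already initialized (e.g. via torchrun, as in the DDP sketch).
fsdp_model = FSDP(torch.nn.Linear(1024, 1024).cuda())

# Every rank runs the same save call and writes only its own shards in parallel,
# so no rank ever has to materialize the full model in memory.
state_dict = {"model": fsdp_model.state_dict()}
dcp.save(state_dict, checkpoint_id="checkpoints/step_1000")

# Loading is also collective: each rank loads its own shards back in place.
dcp.load(state_dict, checkpoint_id="checkpoints/step_1000")
fsdp_model.load_state_dict(state_dict["model"])
```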