Distributed Optimizers

Author: Wenxuan Tan, Junwen Duan, Renjie Mao

Related Paper

Introduction

Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters, and thus aren't directly applicable to settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communication, along with seamless integration with the Tensor Parallel, DDP and ZeRO plugins, which automatically switch to the distributed optimizers with zero code change.

Optimizers

Adafactor is a first-order Adam variant that uses non-negative matrix factorization (NMF) to reduce the memory footprint. CAME improves on it by introducing a confidence matrix to correct the NMF approximation. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without losing accuracy via layer-wise adaptive updates bounded by the inverse of the layer's Lipschitz constant.
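
To make the memory saving concrete, below is a minimal PyTorch sketch of two of the core ideas: Adafactor-style factored second moments and GaLore-style low-rank gradient projection. This illustrates only the underlying math, not the library's implementation; the tensor shapes and the rank are arbitrary choices for the example.

import torch

# Sketch of Adafactor's factored second moment: for an (n, m) weight, keep
# per-row and per-column statistics (n + m numbers) instead of the full
# (n, m) second-moment matrix that Adam stores.
grad = torch.randn(1024, 4096)
row_stat = (grad ** 2).mean(dim=1)                            # shape (n,)
col_stat = (grad ** 2).mean(dim=0)                            # shape (m,)
v_approx = torch.outer(row_stat, col_stat) / row_stat.mean()  # rank-1 reconstruction

# Sketch of GaLore's low-rank projection: project the gradient onto its top-r
# left singular vectors so optimizer states live in the small (r, m) space.
r = 8
U, _, _ = torch.linalg.svd(grad, full_matrices=False)
P = U[:, :r]                                                  # (n, r) projector
low_rank_grad = P.T @ grad                                    # (r, m) compressed gradient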

Hands-On Practice

We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs. Note that even if you're not aware of distributed optimizers, the plugins automatically cast your optimizer to the distributed version for convenience.

step 1. Import libraries

from transformers import LlamaModel, LlamaConfig
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
import colossalai
import torch

step 2. Initialize Distributed Environment and Parallelism Group

We need to initialize the distributed environment. For demo purposes, we launch with colossalai run --nproc_per_node 4. You can refer to Launch Colossal-AI for other launch methods.

colossalai.launch_from_torch()

step 3. Initialize Module and Optimizer

Build our model and optimizer. Here we initialize a Llama model from Hugging Face, define a simple mean criterion, and construct the distributed optimizer.

# Init Llama from huggingface
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
criterion = lambda x: x.mean()
dist_optim = DistributedAdaFactor(model.parameters())

step 4. Init Booster

plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
booster = Booster(plugin=plugin)
# You should also pass in your own dataset.
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
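
For a real training run you would also build a dataloader and hand it to the booster. The snippet below is a hedged sketch: your_dataset is a placeholder for your own torch Dataset, and it assumes the plugin's prepare_dataloader helper for data-parallel-aware sharding.

# your_dataset is a placeholder; replace it with your own torch Dataset.
dataloader = plugin.prepare_dataloader(your_dataset, batch_size=8, shuffle=True, drop_last=True)
model, dist_optim, criterion, dataloader, _ = booster.boost(
    model, dist_optim, criterion=criterion, dataloader=dataloader
)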

step 5. Train Your Model

steps = 10
for step in range(steps):
    input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
    attention_mask = input_ids.clone()
    outputs = model(input_ids.cuda(), attention_mask.cuda())
    loss = criterion(outputs.last_hidden_state)
    booster.backward(loss, dist_optim)
    dist_optim.step()
    dist_optim.zero_grad()

GaLore special handling

For GaLore, we need to specify a projection rank for each parameter group, along with the quantization and paged-optimizer parameters. Please refer to bitsandbytes for quantization details. Support for ZeRO is underway.

from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.nn.optimizer import DistGaloreAwamW

optim = DistGaloreAwamW(
    get_galore_param_groups(model, decay=1e-2, rank=8),
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    nbits=8,
    percentile_clipping=100,
    block_wise=True,
    min_8bit_size=4096,
)

Plugin compatibility

Optimizer/Plugin   Hybrid Parallel Plugin   Low Level Zero Plugin   Torch DDP Plugin   Gemini Plugin   Moe Hybrid Plugin
Lamb               ✔️                       ✔️                      ✔️
GaLore             ✔️                       ✔️                      ✔️
Adafactor          ✔️                       ✔️                      ✔️
CAME               ✔️                       ✔️                      ✔️

API Reference

class colossalai.nn.DistributedAdaFactor(params, lr = None, eps = (1e-30, 0.001), clip_threshold = 1.0, decay_rate = -0.8, beta1 = None, weight_decay = 0.0, scale_parameter = True, relative_step = True, warmup_init = False)

function setup_distributed(tp_group: ProcessGroup = None, dp_group: ProcessGroup = None, shard_to_working_param: typing.Dict = {}, padding_map = None, use_zero: bool = True)

Parameters
  • tp_group (ProcessGroup) -- The process group for tensor parallelism.
  • dp_group (ProcessGroup) -- The process group for data parallelism.
  • shard_to_working_param (Dict) -- ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. This maps from id(view) to the working params used in forward & backward.
  • padding_map -- An empty interface placeholder.
  • use_zero (bool) -- Whether or not to use ZeRO.

Description
Set up process groups for TP and ZeRO 2 and inject the distributed features into the optimizer.
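
For reference, a manual setup without the booster could look like the sketch below, mirroring the DistributedLamb example later in this reference. This is an illustration under assumptions: in normal use the HybridParallelPlugin/ZeRO plugin calls setup_distributed for you, and the 2x2 mesh layout and arguments here are arbitrary choices for a 4-device run without ZeRO.

# Illustrative sketch only: in practice the booster/plugin performs this wiring.
from colossalai.cluster import ProcessGroupMesh

optim = DistributedAdaFactor(model.parameters())
proc_mesh = ProcessGroupMesh(2, 2)              # assumed (tp_size, dp_size) layout on 4 devices
tp_group = proc_mesh.get_group_along_axis(0)
dp_group = proc_mesh.get_group_along_axis(1)
optim.setup_distributed(tp_group=tp_group, dp_group=dp_group, use_zero=False)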

function step(closure = None)

Parameters
  • closure (callable, optional) -- A closure that reevaluates the model and returns the loss.

Description
Performs a single optimization step.

class colossalai.nn.DistributedLamb(params, lr = 0.001, betas = (0.9, 0.999), eps = 1e-06, weight_decay = 0, bias_correction = True)

Parameters
  • params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups
  • lr (float, optional) -- learning rate (default: 1e-3)
  • betas (Tuple[float, float], optional) -- coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
  • eps (float, optional) -- term added to the denominator to improve numerical stability (default: 1e-6)
  • weight_decay (float, optional) -- weight decay (L2 penalty) (default: 0)

Description
Implements the Lamb algorithm with extra support for ZeRO 2 and Tensor Parallel. Proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes (https://arxiv.org/abs/1904.00962). It's recommended to use this with the HybridParallelPlugin/ZeRO plugin and booster, which will take care of setup_distributed. Example with 4 devices:

>>> optim = DistributedLamb(model.parameters(), lr=1e-3)
>>> proc_mesh = ProcessGroupMesh(tp_size, zero_size)
>>> tp_group = proc_mesh.get_group_along_axis(0)
>>> dp_group = proc_mesh.get_group_along_axis(1)
>>> optim.setup_distributed(tp_group, dp_group)

function setup_distributed(tp_group: typing.Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, dp_group: typing.Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, shard_to_working_param: typing.Optional[typing.Dict] = {}, padding_map = None, is_zero: typing.Optional[bool] = False)

Parameters
  • tp_group (dist.ProcessGroup) -- Tensor Parallel process group
  • dp_group (dist.ProcessGroup) -- ZeRO 2 process group
  • shard_to_working_param (Dict) -- ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. This maps from id(view) to the working params used in forward & backward.
  • padding_map -- An empty interface placeholder.
  • is_zero (bool) -- Whether to use ZeRO 2.

Description
Assign process groups for TP and ZeRO 2.

function step(closure = None)

Parameters
  • closure (callable, optional) -- A closure that reevaluates the model and returns the loss.

Description
Performs a single optimization step.

class colossalai.nn.DistGaloreAwamW(params, lr = 0.01, betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.01, nbits = 8, min_8bit_size = 4096, percentile_clipping = 100, block_wise = True, is_paged = False, args = None)

Parameters
  • params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups.
  • lr (float, optional) -- learning rate. (default: 1e-2)
  • betas (Tuple[float, float], optional) -- coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999))
  • eps (float, optional) -- term added to the denominator to improve numerical stability. (default: 1e-8)
  • weight_decay (float, optional) -- weight decay (L2 penalty) (default: 0.01)
  • nbits (int) -- Number of bits for quantizing optimizer states. Only 32 and 8 are supported.
  • min_8bit_size (int, defaults to 4096) -- The minimum number of elements of the parameter tensors for 8-bit optimization.
  • percentile_clipping (int, defaults to 100) -- Adapts the clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
  • block_wise (bool, defaults to True) -- Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
  • is_paged (bool, defaults to False) -- Whether the optimizer is a paged optimizer (handles memory spikes via CPU-GPU transfer) or not.
  • args (dict, optional) -- quantization-related arguments. If passed, will override all quantization args above.

Description
Implements GaLore, an optimizer-agnostic gradient compression technique, on top of 8-bit AdamW. It largely compresses gradients via low-rank projection and is claimed to be insensitive to hyperparameters like lr. Supports Tensor Parallel and ZeRO stages 1 and 2 via the booster and plugin. Proposed in GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection (https://arxiv.org/abs/2403.03507).

function setup_distributed(tp_group: typing.Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, dp_group: typing.Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, shard_to_working_param: typing.Optional[typing.Dict] = {}, padding_map: typing.Optional[typing.Dict] = defaultdict(int), is_zero: typing.Optional[bool] = False)

Parameters
  • tp_group (dist.ProcessGroup) -- Tensor Parallel process group
  • dp_group (dist.ProcessGroup) -- ZeRO 2 process group
  • shard_to_working_param (Dict) -- ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. This maps from id(view) to the working params used in forward & backward.
  • padding_map (Dict) -- Padding size of each param from ZeRO's param store. Required if ZeRO is used.
  • is_zero (bool) -- Whether to use ZeRO 2.

Description
Set up process groups for TP and ZeRO 2.

function step(closure = None)

Parameters
  • closure (callable, optional) -- A closure that reevaluates the model and returns the loss.

Description
Performs a single optimization step.

function to_master_shape(data, padding)

Description
Pad to the master (optimizer) param shape.

class colossalai.nn.DistributedCAME(params, lr = None, eps = (1e-30, 1e-16), clip_threshold = 1.0, betas = (0.9, 0.999, 0.9999), weight_decay = 0.0)

Parameters
  • params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups
  • lr (float, optional) -- external learning rate (default: None)
  • eps (tuple[float, float]) -- regularization constants for square gradient and instability respectively (default: (1e-30, 1e-16))
  • clip_threshold (float) -- threshold of root-mean-square of final gradient update (default: 1.0)
  • betas (tuple[float, float, float]) -- coefficients used for computing running averages of update, square gradient and instability (default: (0.9, 0.999, 0.9999))
  • weight_decay (float, optional) -- weight decay (L2 penalty) (default: 0)

Description
Implements the CAME algorithm. This implementation is based on CAME: Confidence-guided Adaptive Memory Efficient Optimization.

function setup_distributed(tp_group: ProcessGroup = None, dp_group: ProcessGroup = None, shard_to_working_param: typing.Dict = {}, padding_map = None, use_zero: bool = True)

Parameters
  • tp_group (ProcessGroup) -- The process group for tensor parallelism.
  • dp_group (ProcessGroup) -- The process group for data parallelism.
  • shard_to_working_param (Dict) -- ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. This maps from id(view) to the working params used in forward & backward.
  • padding_map -- Interface placeholder.
  • use_zero (bool) -- Whether or not to use ZeRO.

Description
Set up process groups for TP and ZeRO 2 and inject the distributed features into the optimizer.

function step(closure = None)

Parameters
  • closure (callable, optional) -- A closure that reevaluates the model and returns the loss.

Description
Performs a single optimization step.