Booster Plugins

Author: Hongxin Liu

Prerequisite: Booster API

Introduction

As mentioned in Booster API, we can use booster plugins to customize parallel training. In this tutorial, we introduce how to use the booster plugins; a minimal end-to-end sketch is given after the plugin list below.

We currently provide the following plugins:

  • Low Level Zero Plugin: It wraps colossalai.zero.low_level.LowLevelZeroOptimizer and can be used to train models with ZeRO data parallelism. It only supports ZeRO stage-1 and stage-2.
  • Gemini Plugin: It wraps Gemini, which implements ZeRO-3 with chunk-based and heterogeneous memory management.
  • Torch DDP Plugin: It is a wrapper of torch.nn.parallel.DistributedDataParallel and can be used to train models with data parallelism.
  • Torch FSDP Plugin: It is a wrapper of torch.distributed.fsdp.FullyShardedDataParallel and can be used to train models with ZeRO data parallelism.

More plugins are coming soon.
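
All plugins follow the same workflow: launch the distributed environment, construct a plugin, wrap your training objects with a Booster, and train as usual. The snippet below is a minimal end-to-end sketch using Torch DDP; the toy model, dataset and the launch_from_torch(config={}) call are illustrative assumptions and may need adjustment for your ColossalAI version.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# script is started with torchrun; the config argument may differ across versions
colossalai.launch_from_torch(config={})

model = nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

for x, y in train_dataloader:
    loss = criterion(model(x.cuda()), y.cuda())
    booster.backward(loss, optimizer)  # use booster.backward instead of loss.backward
    optimizer.step()
    optimizer.zero_grad()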

Plugins

Low Level Zero Plugin

This plugin implements ZeRO-1 and ZeRO-2 (with or without CPU offload), using reduce and gather to synchronize gradients and weights.

ZeRO-1 can be regarded as a better substitute for Torch DDP: it is more memory-efficient and faster, and it can be easily used in hybrid parallelism.

ZeRO-2 does not support local gradient accumulation. You can still accumulate gradients if you insist, but doing so does not reduce communication cost, so combining ZeRO-2 with pipeline parallelism is not a good idea.
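
The stage and CPU offload are selected directly through the plugin's constructor. The snippet below is a minimal sketch of a ZeRO-2 configuration with CPU offload; it assumes a distributed environment has already been launched, as in the end-to-end sketch above.

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# ZeRO-2 with gradients, master weights and optimizer states offloaded to CPU
plugin = LowLevelZeroPlugin(stage=2, precision='fp16', cpu_offload=True)
booster = Booster(plugin=plugin)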

class colossalai.booster.plugin.LowLevelZeroPlugin(stage: int = 1, precision: str = 'fp16', initial_scale: float = 4294967296, min_scale: float = 1, growth_factor: float = 2, backoff_factor: float = 0.5, growth_interval: int = 1000, hysteresis: int = 2, max_scale: float = 4294967296, max_norm: float = 0.0, norm_type: float = 2.0, reduce_bucket_size_in_m: int = 12, communication_dtype: typing.Optional[torch.dtype] = None, overlap_communication: bool = True, cpu_offload: bool = False, verbose: bool = False)
Parameters
  • stage (int, optional) -- ZeRO stage. Defaults to 1.
  • precision (str, optional) -- precision. Supports 'fp16' and 'fp32'. Defaults to 'fp16'.
  • initial_scale (float, optional) -- Initial scale used by DynamicGradScaler. Defaults to 2**32.
  • min_scale (float, optional) -- Min scale used by DynamicGradScaler. Defaults to 1.
  • growth_factor (float, optional) -- growth_factor used by DynamicGradScaler. Defaults to 2.
  • backoff_factor (float, optional) -- backoff_factor used by DynamicGradScaler. Defaults to 0.5.
  • growth_interval (int, optional) -- growth_interval used by DynamicGradScaler. Defaults to 1000.
  • hysteresis (int, optional) -- hysteresis used by DynamicGradScaler. Defaults to 2.
  • max_scale (float, optional) -- max_scale used by DynamicGradScaler. Defaults to 2**32.
  • max_norm (float, optional) -- max_norm used for clip_grad_norm. Note that you should not call clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping.
  • norm_type (float, optional) -- norm_type used for clip_grad_norm.
  • reduce_bucket_size_in_m (int, optional) -- grad reduce bucket size in M. Defaults to 12.
  • communication_dtype (torch.dtype, optional) -- communication dtype. If not specified, the dtype of param will be used. Defaults to None.
  • overlap_communication (bool, optional) -- whether to overlap communication and computation. Defaults to True.
  • cpu_offload (bool, optional) -- whether to offload gradients, master weights and optimizer states to CPU. Defaults to False.
  • verbose (bool, optional) -- verbose mode. Debug info including grad overflow will be printed. Defaults to False.
Description

Plugin for low level zero.

Example:

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

model, train_dataset, optimizer, criterion = ...
plugin = LowLevelZeroPlugin()

train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

We have tested compatibility with a number of well-known models; the following models may not be supported:

  • timm.models.convit_base
  • dlrm and deepfm models in torchrec
  • diffusers.VQModel
  • transformers.AlbertModel
  • transformers.AlbertForPreTraining
  • transformers.BertModel
  • transformers.BertForPreTraining
  • transformers.GPT2DoubleHeadsModel

Compatibility problems will be fixed in the future.

⚠ This plugin can currently only load optimizer checkpoints saved by itself with the same number of processes. This will be fixed in the future.
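
To illustrate this constraint, the sketch below assumes the Booster checkpoint I/O methods booster.save_optimizer and booster.load_optimizer; both the saving run and the loading run must be launched with the same number of processes and the same plugin.

# after booster.boost(...) as in the example above
ckpt_path = './low_level_zero_optimizer.pt'

# run A: save with, e.g., 4 processes
booster.save_optimizer(optimizer, ckpt_path)

# run B: load later, also launched with 4 processes and this plugin
booster.load_optimizer(optimizer, ckpt_path)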

Gemini Plugin

This plugin implements ZeRO-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. Like ZeRO-2, it does not support local gradient accumulation. More details can be found in the Gemini Doc.
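
In particular, the placement policy controls where parameters and optimizer states reside. The snippet below is a minimal sketch assuming a distributed environment has already been launched; the 'auto' policy moves data between CPU and GPU based on collected memory statistics.

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# let Gemini place chunks on CPU or GPU dynamically; pin CPU memory for faster transfer
plugin = GeminiPlugin(placement_policy='auto', pin_memory=True)
booster = Booster(plugin=plugin)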

class colossalai.booster.plugin.GeminiPlugin(device: typing.Optional[torch.device] = None, placement_policy: str = 'cpu', pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, search_range_mb: int = 32, hidden_dim: typing.Optional[int] = None, min_chunk_size_mb: float = 32, memstats: typing.Optional[colossalai.zero.gemini.memory_tracer.memory_stats.MemStats] = None, gpu_margin_mem_ratio: float = 0.0, initial_scale: float = 4294967296, min_scale: float = 1, growth_factor: float = 2, backoff_factor: float = 0.5, growth_interval: int = 1000, hysteresis: int = 2, max_scale: float = 4294967296, max_norm: float = 0.0, norm_type: float = 2.0, verbose: bool = False)
Parameters
  • device (torch.device) -- device to place the model.
  • placement_policy (str, optional) -- "cpu", "cuda", "auto". Defaults to "cpu".
  • pin_memory (bool, optional) -- use pin memory on CPU. Defaults to False.
  • force_outputs_fp32 (bool, optional) -- force outputs to be fp32. Defaults to False.
  • strict_ddp_mode (bool, optional) -- use strict ddp mode (only use dp without other parallelism). Defaults to False.
  • search_range_mb (int, optional) -- chunk size searching range in MegaByte. Defaults to 32.
  • hidden_dim (int, optional) -- the hidden dimension of the DNN. Providing this argument speeds up the chunk size search. If you do not know it before training, that is fine; a default value of 1024 will be used.
  • min_chunk_size_mb (float, optional) -- the minimum chunk size in MegaByte. If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
  • memstats (MemStats, optional) -- the memory statistics collected by a runtime memory tracer.
  • gpu_margin_mem_ratio (float, optional) -- The ratio of GPU remaining memory (after the first forward-backward) which will be used when using hybrid CPU optimizer. This argument is meaningless when placement_policy of GeminiManager is not "auto". Defaults to 0.0.
  • initial_scale (float, optional) -- Initial scale used by DynamicGradScaler. Defaults to 2**32.
  • min_scale (float, optional) -- Min scale used by DynamicGradScaler. Defaults to 1.
  • growth_factor (float, optional) -- growth_factor used by DynamicGradScaler. Defaults to 2.
  • backoff_factor (float, optional) -- backoff_factor used by DynamicGradScaler. Defaults to 0.5.
  • growth_interval (int, optional) -- growth_interval used by DynamicGradScaler. Defaults to 1000.
  • hysteresis (int, optional) -- hysteresis used by DynamicGradScaler. Defaults to 2.
  • max_scale (float, optional) -- max_scale used by DynamicGradScaler. Defaults to 2**32.
  • max_norm (float, optional) -- max_norm used for clip_grad_norm. Note that you should not call clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping.
  • norm_type (float, optional) -- norm_type used for clip_grad_norm.
  • verbose (bool, optional) -- verbose mode. Debug info including chunk search result will be printed. Defaults to False.
Description

Plugin for Gemini.

Example:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

model, train_dataset, optimizer, criterion = ...
plugin = GeminiPlugin()

train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

⚠ This plugin can currently only load optimizer checkpoints saved by itself with the same number of processes. This will be fixed in the future.

Torch DDP Plugin

More details can be found in the PyTorch docs.

class colossalai.booster.plugin.TorchDDPPlugin(broadcast_buffers: bool = True, bucket_cap_mb: int = 25, find_unused_parameters: bool = False, check_reduction: bool = False, gradient_as_bucket_view: bool = False, static_graph: bool = False)
Parameters
  • broadcast_buffers (bool, optional) -- Whether to broadcast buffers at the beginning of training. Defaults to True.
  • bucket_cap_mb (int, optional) -- The bucket size in MB. Defaults to 25.
  • find_unused_parameters (bool, optional) -- Whether to find unused parameters. Defaults to False.
  • check_reduction (bool, optional) -- Whether to check reduction. Defaults to False.
  • gradient_as_bucket_view (bool, optional) -- Whether to use gradient as bucket view. Defaults to False.
  • static_graph (bool, optional) -- Whether to use static graph. Defaults to False.
Description

Plugin for PyTorch DDP.

Example:

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model, train_dataset, optimizer, criterion = ...
plugin = TorchDDPPlugin()

train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
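
The constructor arguments map directly onto torch.nn.parallel.DistributedDataParallel options. The snippet below is an illustrative sketch, not a recommendation of particular flags.

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# detect parameters that receive no gradients (e.g. models with conditional branches)
# and let gradients share storage with the communication buckets to save memory
plugin = TorchDDPPlugin(find_unused_parameters=True, gradient_as_bucket_view=True)
booster = Booster(plugin=plugin)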

Torch FSDP Plugin

⚠ This plugin is not available when the torch version is lower than 1.12.0.

⚠ This plugin does not currently support saving/loading sharded model checkpoints.

⚠ This plugin does not support optimizers that use multiple parameter groups.

More details can be found in the PyTorch docs.

class colossalai.booster.plugin.TorchFSDPPlugin(process_group: typing.Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, sharding_strategy: typing.Optional[torch.distributed.fsdp.api.ShardingStrategy] = None, cpu_offload: typing.Optional[torch.distributed.fsdp.api.CPUOffload] = None, auto_wrap_policy: typing.Optional[typing.Callable] = None, backward_prefetch: typing.Optional[torch.distributed.fsdp.api.BackwardPrefetch] = None, mixed_precision: typing.Optional[torch.distributed.fsdp.api.MixedPrecision] = None, ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None, param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], None]] = None, sync_module_states: bool = False)
Parameters
  • See https://pytorch.org/docs/stable/fsdp.html for details.
Description

Plugin for PyTorch FSDP.

Example:

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

model, train_dataset, optimizer, criterion = ...
plugin = TorchFSDPPlugin()

train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
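
Configuration objects are passed straight through to torch.distributed.fsdp. The snippet below is an illustrative sketch assuming torch >= 1.12 and an already-launched distributed environment; it enables parameter CPU offload and fp16 mixed precision.

import torch
from torch.distributed.fsdp import CPUOffload, MixedPrecision

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

plugin = TorchFSDPPlugin(
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)
booster = Booster(plugin=plugin)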