Booster API
Author: Mingyan Jiang, Jianghai Chen, Baizhou Zhang
Prerequisite:
Example Code
Introduction
In our new design, colossalai.booster replaces colossalai.initialize as the way to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. These new APIs make it easier to integrate your model with our parallelism features. Calling colossalai.booster is also the standard step before you enter your training loop. The sections below cover how colossalai.booster works and what to watch out for.
Plugin
A plugin is an important component that manages the parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). The currently supported plugins are as follows:
HybridParallelPlugin: This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallel, pipeline parallel and data parallel strategies including DDP and ZeRO.
GeminiPlugin: This plugin wraps the Gemini acceleration solution, which is ZeRO with chunk-based memory management.
TorchDDPPlugin: This plugin wraps the DDP acceleration solution of PyTorch. It implements data parallelism at the module level and can run across multiple machines.
LowLevelZeroPlugin: This plugin wraps stages 1 and 2 of the Zero Redundancy Optimizer. Stage 1 shards optimizer states across data parallel workers/GPUs. Stage 2 shards optimizer states and gradients across data parallel workers/GPUs.
TorchFSDPPlugin: This plugin wraps the FSDP acceleration solution of PyTorch and can be used to train models with zero-dp.
More details on the usage of each plugin can be found in the chapter Booster Plugins.
Some plugins support lazy initialization, which can be used to save memory when initializing large models. For more details, please see Lazy Initialization.
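As a rough illustration of how one of the plugins listed above might be picked at runtime, the sketch below maps plugin class names to constructor keyword arguments and imports the class lazily. The helper make_plugin and the PLUGIN_KWARGS table are hypothetical names, and the stage=2 keyword is an assumption about the LowLevelZeroPlugin constructor; check each plugin's own documentation for its real options.

```python
import importlib

# Illustrative defaults only. The keyword arguments are assumptions, so
# consult each plugin's documentation before relying on them.
PLUGIN_KWARGS = {
    "TorchDDPPlugin": {},
    "TorchFSDPPlugin": {},
    "GeminiPlugin": {},
    "LowLevelZeroPlugin": {"stage": 2},  # assumed: selects the ZeRO stage
    "HybridParallelPlugin": {},  # parallel sizes must come from the caller
}


def make_plugin(name, **overrides):
    """Instantiate a plugin by class name (hypothetical helper)."""
    if name not in PLUGIN_KWARGS:
        raise ValueError(f"unknown plugin: {name!r}")
    # Import lazily so that validating a name does not require ColossalAI
    # to be importable until a plugin is actually built.
    module = importlib.import_module("colossalai.booster.plugin")
    plugin_cls = getattr(module, name)
    return plugin_cls(**{**PLUGIN_KWARGS[name], **overrides})
```

The returned object would then be passed to Booster(plugin=...). Keeping the choice in one table makes it easy to switch strategies from a command-line flag without touching the training loop.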
API of booster
class
colossalai.booster.Booster
- device (str or torch.device) -- The device to run the training. Default: None. If plugin is not used or plugin doesn't control the device, this argument will be set as training device ('cuda' will be used if argument is None).
- mixed_precision (str or MixedPrecision) -- The mixed precision to run the training. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' would use PyTorch AMP while 'fp16_apex' would use Nvidia Apex.
- plugin (Plugin) -- The plugin to run the training. Default: None.
Booster is a high-level API for training neural networks. It provides a unified interface for training with different precision, accelerator, and plugin.
# Following is pseudocode

colossalai.launch(...)
plugin = GeminiPlugin(...)
# the keyword is mixed_precision, matching the parameter list above
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids.cuda(), attention_mask.cuda())
        loss = criterion(outputs.logits, input_ids)
        # backward goes through the booster instead of loss.backward()
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
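To make the constructor arguments above concrete, here is a minimal sketch that checks a mixed_precision string against the values listed in the parameter description before the arguments are handed to Booster. The helper booster_kwargs and the VALID_PRECISIONS tuple are hypothetical names introduced for illustration; they are not part of ColossalAI.

```python
# String values listed in the mixed_precision description above.
VALID_PRECISIONS = ("fp16", "fp16_apex", "bf16", "fp8")


def booster_kwargs(plugin=None, mixed_precision=None, device=None):
    """Collect keyword arguments for Booster(...) (hypothetical helper)."""
    if isinstance(mixed_precision, str) and mixed_precision not in VALID_PRECISIONS:
        raise ValueError(
            f"mixed_precision must be one of {VALID_PRECISIONS}, got {mixed_precision!r}"
        )
    kwargs = {"plugin": plugin, "mixed_precision": mixed_precision}
    if device is not None:
        # Only forward an explicit device; otherwise Booster applies its default.
        kwargs["device"] = device
    return kwargs
```

With ColossalAI installed, Booster(**booster_kwargs(plugin=plugin, mixed_precision='bf16')) would be the intended call. Failing early on a mistyped precision string is cheaper than discovering it after the distributed environment has been launched.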
function
backward
- loss (torch.Tensor) -- The loss for backpropagation.
- optimizer (Optimizer) -- The optimizer to be updated.
function
boost
- model (nn.Module) -- Convert model into a wrapped model for distributive training. The model might be decorated or partitioned by plugin's strategy after execution of this method.
- optimizer (Optimizer, optional) -- Convert optimizer into a wrapped optimizer for distributive training. The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
- criterion (Callable, optional) -- The function that calculates loss. Defaults to None.
- dataloader (DataLoader, optional) -- The prepared dataloader for training. Defaults to None.
- lr_scheduler (LRScheduler, optional) -- The learning scheduler for training. Defaults to None.
Returns: List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
Wrap and inject features into the passed-in model, optimizer, criterion, lr_scheduler, and dataloader.
function
enable_lora
- model (nn.Module) -- The model to be appended with LoRA modules.
- pretrained_dir (str, optional) -- The path to the pretrained directory, which can be a local directory or the model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub. When set to None, new LoRA configs and weights are created for the model using the passed-in lora_config. Defaults to None.
- lora_config (peft.LoraConfig, optional) -- The LoraConfig passed to peft. Defaults to None.
Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, LoRA configs and weights are loaded from that directory. LoRA in ColossalAI is implemented using the Hugging Face peft library, so the arguments for LoRA configuration are the same as those of peft.
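A minimal sketch of attaching fresh LoRA modules might look like the following. The helper add_new_lora and the hyperparameter values in LORA_KWARGS are illustrative assumptions; the keyword names follow peft.LoraConfig. The sketch also assumes that enable_lora returns the wrapped model, which the description above does not state explicitly.

```python
# Illustrative hyperparameters; the keys follow peft.LoraConfig.
LORA_KWARGS = {"r": 8, "lora_alpha": 32, "lora_dropout": 0.1}


def add_new_lora(booster, model, lora_config=None):
    """Wrap `model` with newly created LoRA modules (hypothetical helper)."""
    if lora_config is None:
        # peft is imported lazily so that the helper can be defined without it.
        from peft import LoraConfig

        lora_config = LoraConfig(**LORA_KWARGS)
    # pretrained_dir is left at None, so new configs and weights are created.
    return booster.enable_lora(model, lora_config=lora_config)
```

Passing an explicit lora_config lets the caller choose task-specific settings, such as target modules, that this sketch leaves at the peft defaults.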
function
execute_pipeline
- data_iter (Iterator) -- The iterator for getting the next batch of data. There are usually two ways to obtain this argument:
  - wrap the dataloader in an iterator: iter(dataloader)
  - get the next batch from the dataloader and wrap that batch in an iterator: iter([batch])
- model (nn.Module) -- The model to execute forward/backward. It should be a model wrapped by a plugin that supports pipeline parallelism.
- criterion (Callable[[Any, Any], torch.Tensor]) -- The criterion to be used. It should take two arguments, the model outputs and the inputs, and return a loss tensor. 'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
- optimizer (Optimizer, optional) -- The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
- return_loss (bool, optional) -- Whether to return loss in the dict returned by this method. Defaults to True.
- return_output (bool, optional) -- Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
Returns: Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}. ret_dict['loss'] is the loss of forward if return_loss is set to True, else None. ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
Execute forward & backward when utilizing pipeline parallel. Return loss or Huggingface style model outputs if needed.
Warning: This function is tailored to pipeline parallelism. When doing pipeline parallel training with booster, do not run the forward/backward pass in the conventional way (model(input)/loss.backward()), as this will cause unexpected errors.
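The two conventions described above, wrapping a single batch with iter([batch]) and adapting a one-argument loss into a two-argument criterion, can be sketched as follows. The helpers as_pipeline_criterion and pipeline_step are hypothetical names, and the sketch assumes the model has already been boosted with a plugin that supports pipeline parallelism.

```python
def as_pipeline_criterion(loss_fn):
    """Adapt a one-argument loss into the (outputs, inputs) form."""
    return lambda outputs, inputs: loss_fn(outputs)


def pipeline_step(booster, batch, model, loss_fn, optimizer=None):
    """Run one pipeline forward pass, plus backward if an optimizer is given.

    Hypothetical helper: forward and backward both happen inside
    execute_pipeline, so model(...) and loss.backward() are never called here.
    """
    data_iter = iter([batch])  # wrap the single batch as an iterator
    result = booster.execute_pipeline(
        data_iter,
        model,
        as_pipeline_criterion(loss_fn),
        optimizer=optimizer,
        return_loss=True,
    )
    if optimizer is not None:
        optimizer.step()
        optimizer.zero_grad()
    return result["loss"]
```

Calling pipeline_step without an optimizer corresponds to the forward-only (evaluation) case mentioned in the parameter list.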
function
load_lr_scheduler
- lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function
load_model
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict function. Defaults to True.
function
load_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
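The three loading functions above are typically called together when resuming training. The sketch below shows one way to do that; the helper resume and the sub-paths under ckpt_dir are an assumed layout rather than one the library prescribes, and all three objects are expected to have been boosted already.

```python
import os


def resume(booster, ckpt_dir, model, optimizer, lr_scheduler, strict=True):
    """Restore model, optimizer and scheduler state (hypothetical helper).

    The names "model", "optimizer" and "lr_scheduler" under ckpt_dir are an
    assumed layout. Model and optimizer paths are directories when the
    checkpoint is sharded, and files otherwise.
    """
    booster.load_model(model, os.path.join(ckpt_dir, "model"), strict=strict)
    booster.load_optimizer(optimizer, os.path.join(ckpt_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler, os.path.join(ckpt_dir, "lr_scheduler"))
```

Loading the model before the optimizer mirrors the usual order of plain PyTorch state_dict restoration; whether the order matters for a particular plugin is not stated here.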
function
no_sync
- model (nn.Module) -- The model whose gradient synchronization is to be disabled, for DDP.
- optimizer (OptimizerWrapper) -- The optimizer whose gradient synchronization is to be disabled, for ZeRO1-1.
Returns: contextmanager: Context to disable gradient synchronization.
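A common use of such a context is gradient accumulation, where synchronization is skipped on every micro-batch except the last. The sketch below is one way to write that; the helper accumulation_step is a hypothetical name, and it assumes that the active plugin supports no_sync and that the arguments are passed in the (model, optimizer) order of the parameter list above.

```python
from contextlib import nullcontext


def accumulation_step(booster, model, optimizer, criterion, micro_batches):
    """Accumulate gradients over micro_batches, syncing only on the last one.

    Hypothetical helper; `criterion` takes the model output and returns a loss.
    """
    last = len(micro_batches) - 1
    for i, batch in enumerate(micro_batches):
        # Disable gradient synchronization on all but the final micro-batch.
        ctx = booster.no_sync(model, optimizer) if i < last else nullcontext()
        with ctx:
            # Scale the loss so the accumulated gradient is an average.
            loss = criterion(model(batch)) / len(micro_batches)
            booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()
```

Skipping the intermediate synchronizations saves communication, since gradients only need to be reduced once per optimizer step.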
function
save_lora_as_pretrained
- model (Union[nn.Module, ModelWrapper]) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint directory. It must be a local path.
- use_safetensors (bool, optional) -- Whether to use safe tensors when saving. Defaults to False.
Save the LoRA adapters and the adapter configuration file to a pretrained checkpoint directory.
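Saved adapters can later be attached to a fresh base model through the pretrained_dir argument of enable_lora. A minimal sketch of that round trip follows; export_adapters and attach_saved_adapters are hypothetical helper names, and the sketch assumes enable_lora returns the wrapped model.

```python
def export_adapters(booster, model, adapter_dir, use_safetensors=False):
    """Write LoRA adapters and their config to adapter_dir (hypothetical helper)."""
    booster.save_lora_as_pretrained(model, adapter_dir, use_safetensors=use_safetensors)
    return adapter_dir


def attach_saved_adapters(booster, base_model, adapter_dir):
    """Wrap base_model with adapters previously saved to adapter_dir."""
    # With pretrained_dir set, configs and weights are loaded from the directory.
    return booster.enable_lora(base_model, pretrained_dir=adapter_dir)
```

Only the adapter weights and configuration are written, so the base model weights have to be available separately when the adapters are attached again.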
function
save_lr_scheduler
- lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function
save_model
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If true, the checkpoint will be a folder with the same format as a Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool, optional) -- whether to gather the distributed tensor to the first device. Default: True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
- use_safetensors (bool, optional) -- Whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
- use_async (bool, optional) -- whether to save the state_dict of model asynchronously. Default: False.
function
save_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool) -- whether to gather the distributed tensor to the first device. Default: True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Save optimizer to checkpoint.
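Because the checkpoint argument is a directory when shard=True and a file otherwise, it can help to derive the paths from the shard flag in one place. The sketch below does this for the three save functions; save_training_state, the sub-path names and the ".pt" suffix are all assumptions made for illustration.

```python
import os


def save_training_state(booster, ckpt_dir, model, optimizer, lr_scheduler, shard=True):
    """Save model, optimizer and scheduler under ckpt_dir (hypothetical helper)."""
    # Sharded checkpoints are directories; unsharded ones are single files.
    # The ".pt" suffix is an assumed naming choice, not a library requirement.
    suffix = "" if shard else ".pt"
    model_path = os.path.join(ckpt_dir, "model" + suffix)
    optim_path = os.path.join(ckpt_dir, "optimizer" + suffix)
    sched_path = os.path.join(ckpt_dir, "lr_scheduler.pt")  # always a file path

    booster.save_model(model, model_path, shard=shard)
    booster.save_optimizer(optimizer, optim_path, shard=shard)
    booster.save_lr_scheduler(lr_scheduler, sched_path)
    return model_path, optim_path, sched_path
```

Returning the paths makes it straightforward to log them or to pass them on to the matching load functions later.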
Usage
In a typical workflow, you first launch the distributed environment at the beginning of the training script and create the objects you need (such as models, optimizers, loss functions, and data loaders). You then call booster.boost to inject features into these objects. After that, you can use the booster APIs and the returned objects for the rest of your training process.
A pseudo-code example is shown below:
import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def train(rank, world_size, port):
    # launch colossalai; rank, world_size and port are supplied by the caller
    colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')

    # create plugin and objects for training
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # use booster.boost to wrap the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # do training as normal, except that the backward should be called by booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # checkpointing using booster api
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    new_model = resnet18()
    booster.load_model(new_model, save_path)
For more design details please see this page.