Booster API
Author: Mingyan Jiang
Prerequisite:
Example Code
Introduction
In our new design, colossalai.booster replaces the role of colossalai.initialize to seamlessly inject features into your training components (e.g. model, optimizer, dataloader). With these new APIs, you can integrate your model with our parallelism features more conveniently. Calling colossalai.booster is also the standard procedure before you run your training loops. In the sections below, I will cover how colossalai.booster works and what we should take note of.
Plugin
Plugin is an important component that manages the parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). Currently supported plugins are as follows; a minimal construction sketch follows the list.
GeminiPlugin: This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.
TorchDDPPlugin: This plugin wraps the DDP acceleration solution, which implements data parallelism at the module level and can run across multiple machines.
LowLevelZeroPlugin: This plugin wraps stage 1 and stage 2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data-parallel workers/GPUs. Stage 2: shards optimizer states and gradients across data-parallel workers/GPUs.
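As a quick illustration, here is a minimal sketch of constructing a plugin and handing it to the Booster. The constructor arguments shown (e.g. stage for LowLevelZeroPlugin) are assumptions and may differ between versions; check each plugin's reference for the full option list.

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

# Pick exactly one plugin matching the parallelism you want.
plugin = TorchDDPPlugin()               # module-level data parallelism via PyTorch DDP
# plugin = GeminiPlugin()               # Gemini: ZeRO with chunk-based memory management
# plugin = LowLevelZeroPlugin(stage=2)  # ZeRO stage 1 or 2 ('stage' argument is an assumption)

booster = Booster(plugin=plugin)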
API of booster
class colossalai.booster.Booster

Booster is a high-level API for training neural networks. It provides a unified interface for training with different precision, accelerator, and plugin.

- device (str or torch.device) -- The device to run the training. Default: 'cuda'.
- mixed_precision (str or MixedPrecision) -- The mixed precision to run the training. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' uses PyTorch AMP, while 'fp16_apex' uses NVIDIA Apex.
- plugin (Plugin) -- The plugin to run the training. Default: None.
Examples (pseudo-code):

colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = Adam(model.parameters())
dataloader = Dataloader(Dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.logits, input_ids)
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
function backward
- loss (torch.Tensor) -- The loss to be backpropagated.
- optimizer (Optimizer) -- The optimizer to be updated.
function boost
- model (nn.Module) -- The model to be boosted.
- optimizer (Optimizer) -- The optimizer to be boosted.
- criterion (Callable) -- The criterion to be boosted.
- dataloader (DataLoader) -- The dataloader to be boosted.
- lr_scheduler (LRScheduler) -- The lr_scheduler to be boosted.
Boost the model, optimizer, criterion, dataloader, and lr_scheduler, and return the boosted objects.
function load_lr_scheduler
- lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function load_model
- model (nn.Module) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Defaults to True.
function load_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
function no_sync
- model (nn.Module) -- The model whose gradient synchronization is to be disabled.
Returns a context manager within which gradient synchronization is disabled.
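For example, below is a minimal sketch of gradient accumulation with no_sync. It assumes model and optimizer have already been boosted (e.g. with TorchDDPPlugin); accum_steps, dataloader, criterion, inputs, and labels are placeholders, and whether no_sync is supported may depend on the plugin.

accum_steps = 4  # hypothetical number of micro-batches per optimizer step

for step, (inputs, labels) in enumerate(dataloader):
    if (step + 1) % accum_steps != 0:
        # Intermediate micro-batch: run forward/backward without gradient all-reduce.
        with booster.no_sync(model):
            loss = criterion(model(inputs), labels) / accum_steps
            booster.backward(loss, optimizer)
    else:
        # Final micro-batch: gradients are synchronized, then parameters are updated.
        loss = criterion(model(inputs), labels) / accum_steps
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()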
function save_lr_scheduler
- lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function save_model
- model (nn.Module) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False; otherwise, it is a directory path.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder; otherwise, it will be a single file. Defaults to False.
- size_per_shard (int, optional) -- Maximum size of each checkpoint shard file in MB. Only used when shard=True. Defaults to 1024.
function save_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False; otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder; otherwise, it will be a single file. Defaults to False.
- size_per_shard (int, optional) -- Maximum size of each checkpoint shard file in MB. Only used when shard=True. Defaults to 1024.
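Putting the checkpoint APIs above together, here is a minimal sketch of saving and restoring the training state. It assumes model, optimizer, and lr_scheduler were boosted as in the earlier examples; the paths are purely illustrative.

# Save: the model and optimizer checkpoints are sharded (directory paths),
# while the lr scheduler is saved as a single file.
booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=1024)
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True)
booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")

# Load: after re-creating and boosting fresh objects, restore their state.
booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer")
booster.load_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")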
Usage
In a typical workflow, you first launch the distributed environment at the beginning of your training script and create the objects you need (such as models, optimizers, loss functions, and data loaders). You then call colossalai.booster to inject features into these objects. After that, you can use the booster APIs and the returned objects to run the rest of your training process.
A pseudo-code example is shown below:
import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def train():
    # launch the distributed environment (rank, world_size, and port come from your launcher)
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create the plugin, booster, and training objects
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # inject features into the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # train as usual, except that the backward pass goes through the booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()

    # save a sharded model checkpoint and load it into a fresh model
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10)

    new_model = resnet18()
    booster.load_model(new_model, save_path)