
Booster API

Author: Mingyan Jiang

Prerequisite:

Example Code

Introduction

In our new design, colossalai.booster replaces the role of colossalai.initialize and seamlessly injects features into your training components (e.g. model, optimizer, dataloader). With these new APIs, you can integrate your model with our parallelism features more easily. Calling colossalai.booster is also the standard procedure to perform before entering your training loop. In the sections below, I will cover how colossalai.booster works and what we should take note of.

Plugin

Plugin is an important component that manages the parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). The currently supported plugins are listed below; a brief instantiation sketch follows the list.

GeminiPlugin: This plugin wraps the Gemini acceleration solution, i.e. ZeRO with chunk-based memory management.

TorchDDPPlugin: This plugin wraps the DDP acceleration solution. It implements data parallelism at the module level and can run across multiple machines.

LowLevelZeroPlugin: This plugin wraps stages 1 and 2 of the Zero Redundancy Optimizer (ZeRO). Stage 1: shards optimizer states across data-parallel workers/GPUs. Stage 2: shards optimizer states + gradients across data-parallel workers/GPUs.
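
A minimal sketch of picking one plugin and handing it to the booster. The constructor arguments shown for LowLevelZeroPlugin (the stage argument) are assumptions and may differ across versions:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin, LowLevelZeroPlugin

# Pick exactly one plugin for the run.
plugin = TorchDDPPlugin()                # module-level data parallelism (DDP)
# plugin = GeminiPlugin()                # ZeRO with chunk-based memory management
# plugin = LowLevelZeroPlugin(stage=2)   # ZeRO stage 1/2 (stage argument assumed)

booster = Booster(plugin=plugin)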

API of booster

class colossalai.booster.Booster(device: str = 'cuda', mixed_precision: Union[MixedPrecision, str] = None, plugin: Optional[Plugin] = None)
Parameters
  • device (str or torch.device) -- The device to run the training. Default: 'cuda'.
  • mixed_precision (str or MixedPrecision) -- The mixed precision to run the training. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' uses PyTorch AMP, while 'fp16_apex' uses NVIDIA Apex.
  • plugin (Plugin) -- The plugin to run the training. Default: None.
Description

Booster is a high-level API for training neural networks. It provides a unified interface for training with different precisions, accelerators, and plugins.

Example

# The following is pseudocode.
colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = Adam(model.parameters())
dataloader = Dataloader(Dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.logits, input_ids)
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

function backward(loss: Tensor, optimizer: Optimizer)
Parameters
  • loss (torch.Tensor) -- The loss to be backpropagated.
  • optimizer (Optimizer) -- The optimizer to be updated.
Description
Backward pass.
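
A hedged micro-example of the call pattern, reusing the names from the pseudocode above: booster.backward is used in place of a plain loss.backward() so that the plugin and mixed-precision logic can participate in the backward pass.

outputs = model(input_ids, attention_mask)
loss = criterion(outputs.logits, input_ids)
booster.backward(loss, optimizer)   # instead of loss.backward()
optimizer.step()
optimizer.zero_grad()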

function boost(model: Module, optimizer: Optimizer, criterion: Callable = None, dataloader: DataLoader = None, lr_scheduler: _LRScheduler = None)
Parameters
  • model (nn.Module) -- The model to be boosted.
  • optimizer (Optimizer) -- The optimizer to be boosted.
  • criterion (Callable) -- The criterion to be boosted.
  • dataloader (DataLoader) -- The dataloader to be boosted.
  • lr_scheduler (LRScheduler) -- The lr_scheduler to be boosted.
Description

Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
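
As the usage example later on this page shows, boost returns the wrapped objects as a tuple in the same order as the parameters above: (model, optimizer, criterion, dataloader, lr_scheduler). A short sketch of unpacking the result, discarding the slot that was not passed in:

# boost returns (model, optimizer, criterion, dataloader, lr_scheduler);
# use "_" for slots you did not pass in.
model, optimizer, criterion, _, lr_scheduler = booster.boost(
    model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler
)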

function load_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Load lr scheduler from checkpoint.

function load_model(model: Module, checkpoint: str, strict: bool = True)
Parameters
  • model (nn.Module) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
  • strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Defaults to True.
Description
Load model from checkpoint.

function load_optimizer(optimizer: Optimizer, checkpoint: str)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
Description
Load optimizer from checkpoint.

function no_sync(model: Module)
Parameters
  • model (nn.Module) -- The model whose gradient synchronization is to be disabled.
Returns

contextmanager: Context to disable gradient synchronization.

Description
Context manager to disable gradient synchronization across DP process groups.
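
A typical use is gradient accumulation under data parallelism: skip the gradient all-reduce for all but the last micro-batch. A minimal sketch, assuming micro_batches is an iterable of input batches and model/criterion/optimizer come from booster.boost:

with booster.no_sync(model):
    # no gradient synchronization for these micro-batches
    for micro_batch in micro_batches[:-1]:
        loss = criterion(model(micro_batch))
        booster.backward(loss, optimizer)
# the last micro-batch triggers the usual gradient synchronization
loss = criterion(model(micro_batches[-1]))
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()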

function save_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Save lr scheduler to checkpoint.

function save_model(model: Module, checkpoint: str, prefix: str = None, shard: bool = False, size_per_shard: int = 1024)
Parameters
  • model (nn.Module) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder; otherwise it will be a single file. Defaults to False.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Description
Save model to checkpoint.
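
A minimal sketch based on the signature above; the paths and shard size are illustrative:

# single-file checkpoint (shard defaults to False)
booster.save_model(model, "./model.pth")
# sharded checkpoint: "./model_ckpt" becomes a folder whose shard files are at most 512 MB each
booster.save_model(model, "./model_ckpt", shard=True, size_per_shard=512)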

function save_optimizer(optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder; otherwise it will be a single file. Defaults to False.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Description
Save optimizer to checkpoint. Warning: saving a sharded optimizer checkpoint is not supported yet.
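
A hedged sketch of saving and restoring the remaining training state with the checkpoint functions above; the paths are illustrative, and the optimizer checkpoint is kept unsharded since sharded optimizer checkpoints are not supported yet:

# save optimizer and lr scheduler state
booster.save_optimizer(optimizer, "./optimizer.pth")
booster.save_lr_scheduler(lr_scheduler, "./lr_scheduler.pth")

# later, after re-creating and boosting the objects, restore their state
booster.load_optimizer(optimizer, "./optimizer.pth")
booster.load_lr_scheduler(lr_scheduler, "./lr_scheduler.pth")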

Usage

In a typical workflow, you should launch the distributed environment at the beginning of the training script and create the objects you need (such as models, optimizers, loss functions, and data loaders) first, then call colossalai.booster to inject features into these objects. After that, you can use our booster APIs together with the returned objects to run the rest of your training process.

A pseudocode example is shown below:

import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

def train(rank: int, world_size: int, port: int):
    # launch the distributed environment; rank/world_size/port are provided by the caller
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create the plugin and booster
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    # create the training components as usual
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # inject features into the components
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # a single training step
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()

    # save a sharded model checkpoint and load it into a fresh model
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10)

    new_model = resnet18()
    booster.load_model(new_model, save_path)

More design details can be found in the related design discussion.