Booster API
Author: Mingyan Jiang, Jianghai Chen, Baizhou Zhang
Prerequisites:
Example Code
Introduction
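In our new design, colossalai.booster replaces colossalai.initialize and seamlessly injects features into your training components (e.g., model, optimizer, dataloader). With the booster API, you can integrate our parallel strategies into the model you are training more conveniently. Calling colossalai.booster is the standard step before you enter your training loop.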
In the following sections, we introduce how colossalai.booster works and the details you should keep in mind when using it.
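Booster Plugins
The Booster plugin is an important component that manages the parallel configuration (e.g., the Gemini plugin encapsulates the Gemini acceleration solution). The currently supported plugins are: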
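HybridParallelPlugin: The HybridParallelPlugin encapsulates the hybrid parallel acceleration solution. Its interface allows any combination of tensor parallelism, pipeline parallelism, and two data-parallel methods (DDP and ZeRO).
GeminiPlugin: The GeminiPlugin encapsulates the Gemini acceleration solution, i.e., ZeRO optimization with chunk-based memory management.
TorchDDPPlugin: The TorchDDPPlugin encapsulates PyTorch's DDP acceleration solution. It implements model-level data parallelism and can run across multiple machines.
LowLevelZeroPlugin: The LowLevelZeroPlugin encapsulates stages 1 and 2 of the Zero Redundancy Optimizer (ZeRO). Stage 1 shards the optimizer states across data-parallel processes or GPUs. Stage 2 shards both the optimizer states and the gradients across data-parallel processes or GPUs.
TorchFSDPPlugin: The TorchFSDPPlugin encapsulates PyTorch's FSDP acceleration solution and can be used for ZeRO data-parallel (ZeroDP) training.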
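For more details on how to use the plugins, please refer to the Booster Plugins section.
Some plugins support lazy initialization, which saves memory when initializing large models. For details, please refer to Lazy Initialization.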
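To put rough numbers on the plugin descriptions above, here is a back-of-the-envelope sketch in plain Python. It assumes the per-parameter accounting from the ZeRO paper for mixed-precision Adam (2 bytes of fp16 weights, 2 bytes of fp16 gradients, 12 bytes of fp32 optimizer state) and that a hybrid setup uses world_size / (tp_size * pp_size) ranks for data parallelism. zero_bytes_per_rank and hybrid_dp_size are hypothetical helpers, not ColossalAI functions, and they ignore activations, buffers, and fragmentation.

```python
# Assumed per-parameter byte costs for mixed-precision Adam (ZeRO paper accounting):
# fp16 weights, fp16 gradients, and fp32 master weights + momentum + variance.
PARAM_BYTES = 2
GRAD_BYTES = 2
OPTIM_STATE_BYTES = 12


def zero_bytes_per_rank(num_params: int, dp_size: int, stage: int) -> float:
    """Rough model-state memory per data-parallel rank (hypothetical helper).

    stage 0: nothing sharded (plain DDP)
    stage 1: optimizer states sharded across dp_size ranks
    stage 2: optimizer states and gradients sharded across dp_size ranks
    """
    if dp_size <= 0:
        raise ValueError("dp_size must be positive")
    if stage == 0:
        return num_params * (PARAM_BYTES + GRAD_BYTES + OPTIM_STATE_BYTES)
    if stage == 1:
        return num_params * (PARAM_BYTES + GRAD_BYTES) + num_params * OPTIM_STATE_BYTES / dp_size
    if stage == 2:
        return num_params * PARAM_BYTES + num_params * (GRAD_BYTES + OPTIM_STATE_BYTES) / dp_size
    raise ValueError("only stages 0, 1 and 2 are modelled here")


def hybrid_dp_size(world_size: int, tp_size: int, pp_size: int) -> int:
    """Number of data-parallel ranks left once tensor/pipeline groups are formed."""
    group = tp_size * pp_size
    if world_size % group != 0:
        raise ValueError(f"world_size {world_size} is not divisible by tp_size * pp_size = {group}")
    return world_size // group


if __name__ == "__main__":
    psi = 1_000_000_000  # a 1B-parameter model
    for stage in (0, 1, 2):
        gib = zero_bytes_per_rank(psi, dp_size=8, stage=stage) / 2**30
        print(f"ZeRO stage {stage}: ~{gib:.1f} GiB of model states per rank")
    print("dp_size with 16 GPUs, tp=2, pp=2:", hybrid_dp_size(16, 2, 2))
```

The point of the comparison is the shape of the savings: stage 1 only divides the large optimizer-state term by the data-parallel size, while stage 2 divides the gradient term as well.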
Booster Interface
class colossalai.booster.Booster
- device (str or torch.device) -- The device to run the training. Default: None. If no plugin is used, or the plugin does not control the device, this argument is used as the training device ('cuda' is used if the argument is None).
- mixed_precision (str or MixedPrecision) -- The mixed precision to run the training. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' uses PyTorch AMP, while 'fp16_apex' uses NVIDIA Apex.
- plugin (Plugin) -- The plugin to run the training. Default: None.
Booster is a high-level API for training neural networks. It provides a unified interface for training with different precisions, accelerators, and plugins.
# Following is pseudocode
colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids.cuda(), attention_mask.cuda())
        loss = criterion(outputs.logits, input_ids)
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
function backward
- loss (torch.Tensor) -- The loss for backpropagation.
- optimizer (Optimizer) -- The optimizer to be updated.
function boost
- model (nn.Module) -- The model to convert into a wrapped model for distributed training. After this method runs, the model may be decorated or partitioned according to the plugin's strategy.
- optimizer (Optimizer, optional) -- The optimizer to convert into a wrapped optimizer for distributed training. After this method runs, the optimizer's param groups or states may be decorated or partitioned according to the plugin's strategy. Defaults to None.
- criterion (Callable, optional) -- The function that calculates the loss. Defaults to None.
- dataloader (DataLoader, optional) -- The prepared dataloader for training. Defaults to None.
- lr_scheduler (LRScheduler, optional) -- The learning rate scheduler for training. Defaults to None.
Returns: List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]] -- The list of boosted input arguments.
Wraps the passed-in model, optimizer, criterion, lr_scheduler, and dataloader, and injects features into them.
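The return contract of boost (a list in the fixed order model, optimizer, criterion, dataloader, lr_scheduler, with arguments you did not pass coming back as None) can be illustrated with a plain-Python stand-in. ToyBooster and Wrapped are hypothetical names. The assumption that only the model and optimizer get wrapped is loosely modelled on a DDP-style plugin; the real Booster delegates this to the selected plugin, and what gets wrapped differs per plugin.

```python
from typing import Any, List, Optional


class Wrapped:
    """Hypothetical stand-in for a plugin-specific wrapper (e.g. a DDP model wrapper)."""

    def __init__(self, inner: Any) -> None:
        self.inner = inner

    def unwrap(self) -> Any:
        return self.inner


class ToyBooster:
    """Illustrates only the ordering/None contract of Booster.boost, nothing else."""

    def boost(
        self,
        model: Any,
        optimizer: Optional[Any] = None,
        criterion: Optional[Any] = None,
        dataloader: Optional[Any] = None,
        lr_scheduler: Optional[Any] = None,
    ) -> List[Any]:
        # Assumption: as with a DDP-style plugin, only model and optimizer are wrapped;
        # the remaining objects are passed through unchanged (None stays None).
        model = Wrapped(model)
        if optimizer is not None:
            optimizer = Wrapped(optimizer)
        return [model, optimizer, criterion, dataloader, lr_scheduler]


booster = ToyBooster()
model, optimizer, criterion, dataloader, scheduler = booster.boost(
    "my_model", "my_optim", criterion=len, lr_scheduler="my_sched"
)
print(type(model).__name__, model.unwrap(), dataloader)
```

Because the order is fixed, callers that skip an argument (here the dataloader) still unpack all five values and simply ignore the slot they did not pass.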
function execute_pipeline
- data_iter (Iterator) -- The iterator for getting the next batch of data. There are usually two ways to obtain this argument:
  - wrap the dataloader into an iterator: iter(dataloader)
  - get the next batch from the dataloader and wrap this batch into an iterator: iter([batch])
- model (nn.Module) -- The model to execute forward/backward. It should be a model wrapped by a plugin that supports pipeline parallelism.
- criterion (Callable[[Any, Any], torch.Tensor]) -- The criterion to use. It should take two arguments, the model outputs and the inputs, and return a loss tensor. 'lambda y, x: loss_fn(y)' turns a normal loss function into a valid two-argument criterion.
- optimizer (Optimizer, optional) -- The optimizer used to execute the backward pass. Can be None when only doing forward (i.e., evaluation). Defaults to None.
- return_loss (bool, optional) -- Whether to return the loss in the dict returned by this method. Defaults to True.
- return_output (bool, optional) -- Whether to return Huggingface-style model outputs in the dict returned by this method. Defaults to False.
Returns: Dict[str, Any] -- Output dict of the form {'loss': ..., 'outputs': ...}. ret_dict['loss'] is the forward loss if return_loss is True, else None. ret_dict['outputs'] is the Huggingface-style model output of the forward pass if return_output is True, else None.
Executes forward and backward when using pipeline parallelism, and returns the loss or Huggingface-style model outputs if requested.
Warning: This function is tailored for pipeline-parallel training. When doing pipeline-parallel training with booster, do not run the forward/backward pass in the conventional way (model(input) / loss.backward()), as doing so will cause unexpected errors.
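The two ways of building data_iter and the two-argument criterion adapter can be shown without any distributed setup. toy_execute_pipeline below is a hypothetical single-process stand-in that only mimics the documented inputs and return dict; it performs no pipeline scheduling and no real backward pass, and the toy model and loss are illustrative assumptions.

```python
from typing import Any, Callable, Dict, Iterator, Optional


def toy_execute_pipeline(
    data_iter: Iterator[Any],
    model: Callable[[Any], Any],
    criterion: Callable[[Any, Any], float],
    return_loss: bool = True,
    return_output: bool = False,
) -> Dict[str, Optional[Any]]:
    """Mimics the documented return contract of Booster.execute_pipeline (toy only)."""
    batch = next(data_iter)           # pull one batch from the iterator
    outputs = model(batch)            # "forward"
    loss = criterion(outputs, batch)  # criterion receives (outputs, inputs)
    return {
        "loss": loss if return_loss else None,
        "outputs": outputs if return_output else None,
    }


def loss_fn(outputs):
    # An ordinary one-argument loss: mean of the outputs.
    return sum(outputs) / len(outputs)


def toy_model(batch):
    # Toy "model" that doubles every input value.
    return [2 * v for v in batch]


# Adapt the one-argument loss to the required (outputs, inputs) signature.
criterion = lambda y, x: loss_fn(y)

dataloader = [[1.0, 2.0], [3.0, 5.0]]

# Way 1: wrap the whole dataloader; each call consumes the next batch.
data_iter = iter(dataloader)
first = toy_execute_pipeline(data_iter, toy_model, criterion)
second = toy_execute_pipeline(data_iter, toy_model, criterion, return_output=True)

# Way 2: wrap a single batch you already hold.
single = toy_execute_pipeline(iter([dataloader[0]]), toy_model, criterion)
print(first, second, single)
```

Way 1 suits a training loop that lets the pipeline pull batches itself; way 2 is handy when you already fetched a batch (for example to inspect it) and only want one step.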
function load_lr_scheduler
- lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function load_model
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict() function. Defaults to True.
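The effect of strict can be sketched with plain dictionaries. check_state_dict_keys is a hypothetical helper that mirrors, in simplified form, how torch.nn.Module.load_state_dict treats missing and unexpected keys: with strict=True any mismatch is an error, with strict=False the mismatches are only reported. It does not load any tensors.

```python
from typing import Dict, List, Tuple


def check_state_dict_keys(
    model_keys: List[str], checkpoint: Dict[str, object], strict: bool = True
) -> Tuple[List[str], List[str]]:
    """Return (missing_keys, unexpected_keys); raise if strict and either is non-empty."""
    expected = set(model_keys)
    provided = set(checkpoint)
    missing = sorted(expected - provided)     # in the model but not in the checkpoint
    unexpected = sorted(provided - expected)  # in the checkpoint but not in the model
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    return missing, unexpected


model_keys = ["fc.weight", "fc.bias"]
ckpt = {"fc.weight": 0, "head.weight": 0}
print(check_state_dict_keys(model_keys, ckpt, strict=False))
```

strict=False is mainly useful for partial loading, for example reusing a backbone checkpoint in a model with a new head.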
function load_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
function no_sync
- model (nn.Module) -- The model for which gradient synchronization is disabled (for DDP).
- optimizer (OptimizerWrapper) -- The optimizer for which gradient synchronization is disabled (for ZeRO-1).
Returns: contextmanager -- Context to disable gradient synchronization.
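no_sync is typically used for gradient accumulation: gradients are accumulated locally over several micro-batches and synchronized only on the last one. The plain-Python simulation below counts how often a hypothetical ToyDDP would all-reduce. It assumes, as with PyTorch DDP's own no_sync, that synchronization is skipped for every backward pass run inside the context manager; with the real API, the intermediate backward calls would go inside with booster.no_sync(model, optimizer):.

```python
from contextlib import contextmanager
from typing import Iterator


class ToyDDP:
    """Hypothetical stand-in that only tracks whether gradients would be all-reduced."""

    def __init__(self) -> None:
        self.sync_enabled = True
        self.local_grad = 0.0
        self.num_allreduce = 0

    def backward(self, grad: float) -> None:
        self.local_grad += grad      # gradients always accumulate locally
        if self.sync_enabled:
            self.num_allreduce += 1  # a real DDP would all-reduce here

    @contextmanager
    def no_sync(self) -> Iterator[None]:
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous  # restore even if the body raises


model = ToyDDP()
accumulation_steps = 4
for step, grad in enumerate([1.0, 2.0, 3.0, 4.0]):
    if step < accumulation_steps - 1:
        with model.no_sync():  # skip communication on intermediate micro-batches
            model.backward(grad)
    else:
        model.backward(grad)   # the final micro-batch triggers one synchronization
print(model.local_grad, model.num_allreduce)
```

Skipping synchronization on all but the last micro-batch keeps the accumulated gradient identical while cutting communication to one collective per optimizer step.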
function save_lr_scheduler
- lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function save_model
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False; otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder with the same format as a Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool, optional) -- Whether to gather the distributed tensor to the first device. Default: True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
- use_safetensors (bool, optional) -- Whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
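To see what size_per_shard controls, here is a greedy sharding sketch over (name, size) pairs. It assumes tensors are packed in order into shards of at most size_per_shard MB (taken as MiB here), that a tensor larger than the limit gets a shard of its own, and that an index in the style of Hugging Face sharded checkpoints (a weight_map from tensor name to file) is written alongside. plan_shards and its file-name pattern are hypothetical; ColossalAI's actual file names and packing details may differ.

```python
from typing import Dict, List, Tuple

MB = 1024 * 1024  # assumption: "MB" interpreted as MiB


def plan_shards(
    tensor_sizes: List[Tuple[str, int]], size_per_shard: int = 1024
) -> Tuple[List[List[str]], Dict[str, object]]:
    """Greedily pack (name, num_bytes) pairs into shards of at most size_per_shard MB."""
    limit = size_per_shard * MB
    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for name, nbytes in tensor_sizes:
        # Close the current shard if adding this tensor would exceed the limit.
        if current and current_size + nbytes > limit:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += nbytes
    if current:
        shards.append(current)

    total = len(shards)
    weight_map = {
        name: f"model-{i + 1:05d}-of-{total:05d}.safetensors"  # hypothetical naming
        for i, names in enumerate(shards)
        for name in names
    }
    index = {"metadata": {"total_size": sum(n for _, n in tensor_sizes)}, "weight_map": weight_map}
    return shards, index


sizes = [("conv1.weight", 4 * MB), ("layer1.weight", 5 * MB), ("fc.weight", 12 * MB), ("fc.bias", 1 * MB)]
shards, index = plan_shards(sizes, size_per_shard=10)
print(shards)
print(index["weight_map"]["fc.bias"])
```

A small size_per_shard (such as the 10 MB used in the usage example) yields many small files, which mainly helps with loading on memory-constrained hosts and with resumable transfers.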
function save_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False; otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool) -- Whether to gather the distributed tensor to the first device. Default: True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Saves the optimizer to a checkpoint.
Usage and Examples
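When training with colossalai, you first need to launch the distributed environment at the beginning of your training script and create the objects you will use, such as the model, optimizer, loss function, and dataloader. Then call booster.boost to inject features into these objects; after that, you can use our booster API for the rest of your training loop.
The following pseudocode example shows how to use our booster API to train a model: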
import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def train(rank: int, world_size: int, port: int):
    # launch colossalai (rank, world_size and port are supplied by your process launcher)
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create plugin and objects for training
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # use booster.boost to wrap the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # do training as normal, except that the backward should be called by booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # checkpointing using booster api
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    # load_model expects a model boosted by Booster, so boost the new model first
    new_model = resnet18()
    new_model, *_ = booster.boost(new_model)
    booster.load_model(new_model, save_path)
For more details on the Booster design, please refer to this page.