Booster API
Authors: Mingyan Jiang, Jianghai Chen, Baizhou Zhang
Prerequisites:
Example code
Introduction
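In our new design, `colossalai.booster` replaces `colossalai.initialize` and seamlessly injects features into your training components (e.g., the model, optimizer, and dataloader). With the booster API, you can integrate our parallelism strategies into the model to be trained more easily. Calling `colossalai.booster` is the standard step before entering your training loop.
In the following sections, we describe how `colossalai.booster` works and the details to keep in mind when using it.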
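Booster Plugins
Booster plugins are the key components for managing parallel configurations (e.g., the Gemini plugin encapsulates the Gemini acceleration scheme). The following plugins are currently supported:
- HybridParallelPlugin: encapsulates the hybrid parallel acceleration solution. Its interface allows arbitrary combinations of tensor parallelism, pipeline parallelism, and two data parallelism methods (DDP and ZeRO).
- GeminiPlugin: encapsulates the Gemini acceleration solution, i.e., ZeRO optimization with chunk-based memory management.
- TorchDDPPlugin: encapsulates PyTorch's DDP acceleration scheme, which implements model-level data parallelism and can run across multiple machines.
- LowLevelZeroPlugin: encapsulates stages 1 and 2 of the Zero Redundancy Optimizer (ZeRO). Stage 1 partitions the optimizer states across the parallel processes or GPUs. Stage 2 partitions both the optimizer states and the gradients across the parallel processes or GPUs.
- TorchFSDPPlugin: encapsulates PyTorch's FSDP acceleration scheme and can be used for Zero Redundancy data-parallel (ZeroDP) training.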
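As a rough illustration of what LowLevelZeroPlugin's stage 1 and stage 2 save, the sketch below estimates per-GPU memory for model states under mixed-precision Adam. It assumes the byte accounting used in the ZeRO paper: 2 bytes per fp16 parameter, 2 bytes per fp16 gradient, and 12 bytes of fp32 optimizer state per parameter (master weights, momentum, variance). It ignores activations, temporary buffers, and fragmentation. The function name is hypothetical and not part of ColossalAI.

```python
def zero_memory_per_gpu_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Estimate model-state memory per GPU in GB for mixed-precision Adam.

    Assumed accounting (ZeRO paper): 2 B fp16 params, 2 B fp16 grads,
    12 B fp32 optimizer states per parameter. Activations are ignored.
    """
    params = 2 * num_params
    grads = 2 * num_params
    optim_states = 12 * num_params
    if stage == 0:
        # plain data parallelism: every GPU holds a full replica of everything
        total = params + grads + optim_states
    elif stage == 1:
        # stage 1: only the optimizer states are partitioned across GPUs
        total = params + grads + optim_states / num_gpus
    elif stage == 2:
        # stage 2: optimizer states and gradients are both partitioned
        total = params + (grads + optim_states) / num_gpus
    else:
        raise ValueError("only stages 0, 1 and 2 are modeled here")
    return total / 1e9


for s in (0, 1, 2):
    print(f"stage {s}: {zero_memory_per_gpu_gb(7.5e9, 64, s):.2f} GB per GPU")
```

For example, a 7.5B-parameter model on 64 GPUs comes out at 120 GB per GPU with plain data parallelism, about 31.4 GB with stage 1, and about 16.6 GB with stage 2. Stage 2 saves more because the gradients are sharded as well.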
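For more details on plugin usage, please refer to the Booster Plugins section.
Some plugins support lazy initialization, which reduces memory usage when initializing large models. For details, please refer to Lazy Initialization.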
Booster Interface
class colossalai.booster.Booster
- device (str or torch.device) -- The device to run the training on. Default: None. If no plugin is used, or the plugin doesn't control the device, this argument is used as the training device ('cuda' is used if the argument is None).
- mixed_precision (str or MixedPrecision) -- The mixed precision to run the training with. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' uses PyTorch AMP, while 'fp16_apex' uses NVIDIA Apex.
- plugin (Plugin) -- The plugin to run the training. Default: None.
Booster is a high-level API for training neural networks. It provides a unified interface for training with different precisions, accelerators, and plugins.
```python
# Following is pseudocode: GPT2, GPTLMLoss, LinearWarmupScheduler, train_dataset
# and max_epochs stand in for your own model, loss, scheduler, dataset and epoch count.
colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids.cuda(), attention_mask.cuda())
        loss = criterion(outputs.logits, input_ids)
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
function backward
- loss (torch.Tensor) -- The loss for backpropagation.
- optimizer (Optimizer) -- The optimizer to be updated.
function boost
- model (nn.Module) -- The model to convert into a wrapped model for distributed training. After this method runs, the model may be decorated or partitioned according to the plugin's strategy.
- optimizer (Optimizer, optional) -- The optimizer to convert into a wrapped optimizer for distributed training. After this method runs, the optimizer's param groups or states may be decorated or partitioned according to the plugin's strategy. Defaults to None.
- criterion (Callable, optional) -- The function that calculates the loss. Defaults to None.
- dataloader (DataLoader, optional) -- The prepared dataloader for training. Defaults to None.
- lr_scheduler (LRScheduler, optional) -- The learning rate scheduler for training. Defaults to None.
Returns: List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
Wrap and inject features into the passed-in model, optimizer, criterion, lr_scheduler, and dataloader.
function enable_lora
- model (nn.Module) -- The model to which LoRA modules are appended.
- pretrained_dir (str, optional) -- The path to the pretrained directory. This can be a local directory or the model_id of a PEFT configuration hosted in a model repo on the Hugging Face Hub. When set to None, new LoRA configs and weights are created for the model from the passed-in lora_config. Defaults to None.
- lora_config (peft.LoraConfig, optional) -- The LoraConfig passed in for peft. Defaults to None.
Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, the LoRA configs and weights are loaded from that directory. LoRA in ColossalAI is implemented with the Hugging Face peft library, so the LoRA configuration arguments are the same as those of peft.
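Since enable_lora delegates to peft, the layers it adds follow the usual LoRA rule. The original weight W stays frozen, and a trainable low-rank update B A is added, scaled by lora_alpha / r (the default scaling in peft's LoraConfig). The plain-Python sketch below is independent of ColossalAI and peft, and its function names are hypothetical. It shows that forward rule and the parameter savings. When B is initialized to zero, as in the original LoRA formulation, the wrapped layer initially reproduces the base layer's output.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    # plain matrix-vector product
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, lora_alpha: float, r: int) -> Vector:
    # h = W x + (lora_alpha / r) * B (A x); W is frozen, only A and B are trained
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = lora_alpha / r
    return [h + scale * u for h, u in zip(base, update)]


def trainable_params(d_out: int, d_in: int, r: int) -> Tuple[int, int]:
    # full fine-tuning updates all of W (d_out x d_in);
    # LoRA trains B (d_out x r) and A (r x d_in) instead
    return d_out * d_in, r * (d_out + d_in)


full, lora = trainable_params(4096, 4096, 8)
print(f"full: {full}, LoRA (r=8): {lora}")
```

For a 4096 x 4096 projection with r = 8, LoRA trains 65,536 parameters instead of 16,777,216.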
function execute_pipeline
- data_iter (Iterator) -- The iterator for getting the next batch of data. There are usually two ways to obtain this argument:
  - wrap the dataloader into an iterator: iter(dataloader)
  - get the next batch from the dataloader and wrap this batch into an iterator: iter([batch])
- model (nn.Module) -- The model to execute forward/backward. It should be a model wrapped by a plugin that supports pipeline parallelism.
- criterion (Callable[[Any, Any], torch.Tensor]) -- The criterion to be used. It should take two arguments, the model outputs and the inputs, and return a loss tensor. 'lambda y, x: loss_fn(y)' turns a normal one-argument loss function into a valid two-argument criterion.
- optimizer (Optimizer, optional) -- The optimizer used for the backward pass. Can be None when only doing forward (i.e., evaluation). Defaults to None.
- return_loss (bool, optional) -- Whether to return the loss in the dict returned by this method. Defaults to True.
- return_output (bool, optional) -- Whether to return Huggingface-style model outputs in the dict returned by this method. Defaults to False.
Returns: Dict[str, Any]: Output dict in the form {'loss': ..., 'outputs': ...}. ret_dict['loss'] is the forward loss if return_loss is True, else None. ret_dict['outputs'] is the Huggingface-style model outputs from the forward pass if return_output is True, else None.
Execute forward and backward when using pipeline parallelism. Returns the loss or Huggingface-style model outputs if needed.
Warning: This function is tailored for pipeline parallelism. When doing pipeline-parallel training with booster, don't run the forward/backward pass in the conventional way (model(input)/loss.backward()), as this will cause unexpected errors.
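The two ways of building data_iter, and the adapter that turns a one-argument loss into the two-argument criterion, can be illustrated in plain Python. The sketch below uses a list as a stand-in for a DataLoader and a hypothetical one-argument loss_fn. Nothing here calls ColossalAI.

```python
from typing import Any, Callable, Iterator, List


def loss_fn(outputs: List[float]) -> float:
    # hypothetical one-argument loss: mean of the model outputs
    return sum(outputs) / len(outputs)


# execute_pipeline calls criterion(outputs, inputs); ignore the inputs here
criterion: Callable[[Any, Any], float] = lambda y, x: loss_fn(y)

# stand-in for a real DataLoader: a list of batches
dataloader = [[1.0, 2.0], [3.0, 4.0]]

# Option 1: wrap the whole dataloader into an iterator
data_iter: Iterator[Any] = iter(dataloader)

# Option 2: take one batch and wrap it as a single-element iterator
batch = next(iter(dataloader))
single_batch_iter: Iterator[Any] = iter([batch])
```

With a pipeline-capable plugin, these objects would then be passed in the parameter order listed above, e.g. `booster.execute_pipeline(data_iter, model, criterion, optimizer)`.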
function load_lr_scheduler
- lr_scheduler (LRScheduler) -- An LR scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function load_model
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict() function. Defaults to True.
function load_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
function no_sync
- model (nn.Module) -- The model for which gradient synchronization is disabled (for DDP).
- optimizer (OptimizerWrapper) -- The optimizer for which gradient synchronization is disabled (for ZeRO-1).
Returns: contextmanager: A context that disables gradient synchronization.
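Disabling gradient synchronization is typically used for gradient accumulation. Each rank accumulates gradients locally over several micro-batches and synchronizes only on the last one. Because averaging across ranks is linear, syncing once at the end gives the same result as syncing after every micro-batch, with fewer communications. The plain-Python simulation below illustrates this. allreduce_mean is a hypothetical stand-in for a real all-reduce, not a ColossalAI function.

```python
from typing import List, Tuple

Grad = List[float]


def allreduce_mean(per_rank: List[Grad]) -> Grad:
    # hypothetical stand-in for an all-reduce that averages gradients across ranks
    n = len(per_rank)
    return [sum(vals) / n for vals in zip(*per_rank)]


def add(a: Grad, b: Grad) -> Grad:
    return [x + y for x, y in zip(a, b)]


def accumulate(steps: List[List[Grad]], use_no_sync: bool) -> Tuple[Grad, int]:
    """steps[s][r] is the local gradient of rank r at micro-batch s.

    Returns the accumulated, rank-averaged gradient and the number of all-reduces.
    """
    num_ranks, dim = len(steps[0]), len(steps[0][0])
    if not use_no_sync:
        # synchronize after every micro-batch, then accumulate the averaged grads
        total = [0.0] * dim
        for grads in steps:
            total = add(total, allreduce_mean(grads))
        return total, len(steps)
    # no_sync: each rank accumulates locally, one all-reduce at the end
    local = [[0.0] * dim for _ in range(num_ranks)]
    for grads in steps:
        local = [add(acc, g) for acc, g in zip(local, grads)]
    return allreduce_mean(local), 1
```

In a real training loop, the `booster.backward(loss, optimizer)` calls for the non-final micro-batches would typically be wrapped in the context manager returned by `booster.no_sync(...)`, passing the model for DDP or the optimizer for ZeRO.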
function save_lora_as_pretrained
- model (Union[nn.Module, ModelWrapper]) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint directory. It must be a local path.
- use_safetensors (bool, optional) -- Whether to use safetensors when saving. Defaults to False.
Save the LoRA adapters and the adapter configuration file to a pretrained checkpoint directory.
function save_lr_scheduler
- lr_scheduler (LRScheduler) -- An LR scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
function save_model
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder with the same format as a Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool, optional) -- Whether to gather the distributed tensor to the first device. Default: True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
- use_safetensors (bool, optional) -- Whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
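To see how size_per_shard bounds the shard files, here is a plain-Python sketch of greedy size-based packing, plus an index that maps each tensor to its file. It is similar in spirit to Hugging Face transformers' sharded checkpoints. This is an illustration under assumptions, not ColossalAI's checkpoint code. The function names and the model-0000X-of-0000N file naming are hypothetical.

```python
from typing import Dict, List, Tuple


def plan_shards(tensor_sizes_mb: List[Tuple[str, float]], size_per_shard: float) -> List[List[str]]:
    """Greedily pack tensors, in order, into shards of at most size_per_shard MB.

    A single tensor larger than size_per_shard still gets a shard of its own.
    """
    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0.0
    for name, size in tensor_sizes_mb:
        # close the current shard if adding this tensor would exceed the limit
        if current and current_size + size > size_per_shard:
            shards.append(current)
            current, current_size = [], 0.0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards


def build_weight_map(shards: List[List[str]], suffix: str = "safetensors") -> Dict[str, str]:
    # map every tensor name to its (hypothetically named) shard file
    total = len(shards)
    return {
        name: f"model-{i + 1:05d}-of-{total:05d}.{suffix}"
        for i, names in enumerate(shards)
        for name in names
    }


sizes = [("a", 4.0), ("b", 4.0), ("c", 4.0), ("d", 12.0), ("e", 1.0)]
shards = plan_shards(sizes, size_per_shard=10.0)
print(shards)
print(build_weight_map(shards))
```

A smaller size_per_shard produces more, smaller files. A tensor that alone exceeds the limit cannot be split by this simple scheme and ends up in its own oversized shard.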
function save_optimizer
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool, optional) -- Whether to gather the distributed tensor to the first device. Default: True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Save the optimizer to a checkpoint.
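Usage and Examples
When training with colossalai, first launch the distributed environment at the beginning of your training script and create the objects you need, such as the model, optimizer, loss function, and dataloader. Then call booster.boost to inject features into these objects. After that, you can use our booster API for the rest of your training process.
The following example shows how to use the booster API to train a model: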
```python
import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def train(rank: int, world_size: int, port: int):
    # launch colossalai (one process per GPU, e.g. started with
    # torch.multiprocessing.spawn(train, args=(world_size, port), nprocs=world_size))
    colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')

    # create plugin and objects for training
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # use booster.boost to wrap the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # do training as normal, except that the backward should be called by booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # checkpointing using booster api
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    # load_model expects a model boosted by Booster, so boost the fresh model first
    new_model = resnet18()
    new_model, *_ = booster.boost(new_model)
    booster.load_model(new_model, save_path)
```
For more details on the Booster design, please refer to this page.