
Booster Checkpoint

Author: Hongxin Liu

Prerequisite tutorial: Booster API

Introduction

We introduced the Booster API in the previous tutorial. In this tutorial, we will cover how to save and load checkpoints with the booster.

Model Checkpoint

function
 

colossalai.booster.Booster.save_model

(model: typing.Union[torch.nn.modules.module.Module, colossalai.interface.model.ModelWrapper], checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False)
Parameters
  • model (nn.Module or ModelWrapper) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If true, the checkpoint will be a folder with the same format as a Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
  • gather_dtensor (bool, optional) -- whether to gather the distributed tensor to the first device. Default: True.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
  • use_safetensors (bool, optional) -- whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
Description
Save model to checkpoint.

The model must be boosted by colossalai.booster.Booster before saving. checkpoint is the path the checkpoint will be saved to: if shard=False it is a file path, otherwise it is a directory path. With shard=True the checkpoint is saved in a sharded way, which is useful when the checkpoint is too large to fit in a single file. Our sharded checkpoint format is compatible with huggingface/transformers, so users can load the model from a sharded checkpoint with huggingface's from_pretrained method.
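
Below is a minimal usage sketch; the plugin choice, toy model, and checkpoint paths are illustrative and not part of the API:

import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Toy model and optimizer; the distributed environment is assumed to have
# been initialized already (e.g. via colossalai.launch_from_torch).
model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# shard=False (default): checkpoint is a single file path.
booster.save_model(model, "./model.pt")

# shard=True: checkpoint is a directory; each shard file is capped at
# size_per_shard MB, in a layout compatible with huggingface/transformers.
booster.save_model(model, "./model_ckpt", shard=True, size_per_shard=1024)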

function
 

colossalai.booster.Booster.load_model

(model: typing.Union[torch.nn.modules.module.Module, colossalai.interface.model.ModelWrapper], checkpoint: str, strict: bool = True)
Parameters
  • model (nn.Module or ModelWrapper) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
  • strict (bool, optional) -- whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict function. Defaults to True.
Description
Load model from checkpoint.

The model must be boosted by colossalai.booster.Booster before loading. It will detect the checkpoint format automatically and load it in the corresponding way.

If you want to load a pretrained model from Huggingface but the model is too large to be loaded directly on a single device via from_pretrained, the recommended approach is to download the pretrained weights to a local directory and, after boosting the model, load them directly from that local path with booster.load_model. To avoid running out of memory, the model should be initialized under Lazy Initialization. Here is example pseudocode:

from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device
from huggingface_hub import snapshot_download
from transformers import LlamaConfig, LlamaForCausalLM
...

# Initialize the model under the lazy init context to avoid materializing
# the full weights in memory.
init_ctx = LazyInitContext(default_device=get_current_device())
with init_ctx:
    model = LlamaForCausalLM(config)

...

# Wrap the model through Booster.boost
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# Download the huggingface pretrained model to a local directory.
model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")

# Load the pretrained weights from the local directory using booster.load_model
booster.load_model(model, model_dir)
...

Optimizer Checkpoint

function
 

colossalai.booster.Booster.save_optimizer

(optimizer: Optimizer, checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
  • gather_dtensor (bool) -- whether to gather the distributed tensor to the first device. Default: True.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Description

Save optimizer to checkpoint.

The optimizer must be boosted by colossalai.booster.Booster before saving.
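
A minimal sketch, reusing the booster and optimizer from the model example above (the paths are illustrative):

# shard=False (default): checkpoint is a single file path.
booster.save_optimizer(optimizer, "./optimizer.pt")

# shard=True: checkpoint is a directory of shard files, each capped at
# size_per_shard MB.
booster.save_optimizer(optimizer, "./optimizer_ckpt", shard=True, size_per_shard=1024)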

function
 

colossalai.booster.Booster.load_optimizer

(optimizer: Optimizer, checkpoint: str)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Description
Load optimizer from checkpoint.

The optimizer must be boosted by colossalai.booster.Booster before loading.
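
A minimal sketch of restoring optimizer states, assuming the checkpoint was written by booster.save_optimizer as above:

# The path may be a single file or a sharded checkpoint directory,
# matching how the checkpoint was saved.
booster.load_optimizer(optimizer, "./optimizer_ckpt")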

Learning Rate Scheduler Checkpoint

function
 

colossalai.booster.Booster.save_lr_scheduler

(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Save lr scheduler to checkpoint.

The learning rate scheduler must be boosted by colossalai.booster.Booster before saving. checkpoint is the local path to the checkpoint file.
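
A minimal sketch, assuming the scheduler is created on the raw optimizer and then passed to booster.boost (the scheduler type and paths are illustrative):

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Variant of the earlier sketch: boost the model, optimizer, and scheduler together.
model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=100)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

# checkpoint is a single local file path.
booster.save_lr_scheduler(lr_scheduler, "./lr_scheduler.pt")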

function
 

colossalai.booster.Booster.load_lr_scheduler

(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Load lr scheduler from checkpoint.

The learning rate scheduler must be boosted by colossalai.booster.Booster before loading. checkpoint is the local path to the checkpoint file.
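
A minimal sketch, continuing the scheduler example above:

# Restore the scheduler state from the file written by save_lr_scheduler.
booster.load_lr_scheduler(lr_scheduler, "./lr_scheduler.pt")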

Checkpoint Design

For more details about the checkpoint design, please see our discussion A Unified Checkpoint System Design.