Booster Checkpoint
Author: Hongxin Liu
Prerequisite:
Introduction
We introduced the Booster API in the previous tutorial. In this tutorial, we show how to save and load checkpoints with the booster.
Model Checkpoint
colossalai.booster.Booster.save_model(model, checkpoint, shard=False, gather_dtensor=True, prefix=None, size_per_shard=1024, use_safetensors=False)
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path: a file path if shard=False, otherwise a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder with the same format as a Hugging Face transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool, optional) -- Whether to gather the distributed tensor to the first device. Defaults to True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. Only used when shard=True. Defaults to 1024.
- use_safetensors (bool, optional) -- Whether to save the checkpoint in the safetensors format. Defaults to False.
The model must be boosted by colossalai.booster.Booster before saving. The checkpoint argument is the path of the saved checkpoint: a file if shard=False, otherwise a directory. If shard=True, the checkpoint is split across several shard files, which is useful when it is too large to be saved as a single file. Our sharded checkpoint format is compatible with huggingface/transformers, so you can load a model from our sharded checkpoint with the Hugging Face from_pretrained method.
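To make size_per_shard and the Hugging Face compatible layout more concrete, here is a minimal sketch of how a state dict could be packed into shard files. It is not ColossalAI's implementation: plan_shards is a hypothetical helper that works on tensor sizes in bytes instead of real tensors, and it assumes 1 MB = 1024 * 1024 bytes. The shard file names and the metadata/weight_map keys follow the Hugging Face pytorch_model.bin.index.json convention.

```python
def plan_shards(param_sizes, size_per_shard=1024):
    """Greedily pack named tensors into shard files of at most size_per_shard MB.

    Hypothetical helper. param_sizes maps each state_dict key to its size in
    bytes; keys keep their insertion order. Returns a dict shaped like a
    Hugging Face pytorch_model.bin.index.json file.
    """
    limit = size_per_shard * 1024 * 1024  # assumption: 1 MB = 1024 * 1024 bytes
    shards = [[]]
    current_bytes = 0
    for name, nbytes in param_sizes.items():
        # Open a new shard when this tensor would push the current one over the limit.
        # A single tensor larger than the limit still gets a shard of its own.
        if shards[-1] and current_bytes + nbytes > limit:
            shards.append([])
            current_bytes = 0
        shards[-1].append(name)
        current_bytes += nbytes

    total = len(shards)
    weight_map = {}
    for i, names in enumerate(shards, start=1):
        # Hugging Face style shard names, e.g. pytorch_model-00001-of-00002.bin
        filename = f"pytorch_model-{i:05d}-of-{total:05d}.bin"
        for name in names:
            weight_map[name] = filename
    return {
        "metadata": {"total_size": sum(param_sizes.values())},
        "weight_map": weight_map,
    }
```

With a real boosted model none of this is done by hand: a call such as booster.save_model(model, "ckpt_dir", shard=True, size_per_shard=1024), where "ckpt_dir" is a placeholder, lets the booster write the shard files and the index itself.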
colossalai.booster.Booster.load_model(model, checkpoint, strict=True)
- model (nn.Module or ModelWrapper) -- A model boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded; otherwise, it should be a file path.
- strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict function. Defaults to True.
The model must be boosted by colossalai.booster.Booster before loading. The checkpoint format (sharded or unsharded) is detected automatically, and the checkpoint is loaded accordingly.
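As a rough illustration of what automatic format detection can involve, here is a small sketch. detect_checkpoint_format is a hypothetical helper, and the rule that a sharded directory is recognised by a *.index.json file is borrowed from the Hugging Face layout, not taken from ColossalAI's source; the real loader may apply further checks (for example for safetensors files).

```python
from pathlib import Path


def detect_checkpoint_format(checkpoint):
    """Return ("sharded", index_file) or ("unsharded", checkpoint_file).

    Hypothetical helper: a directory is treated as a sharded checkpoint and
    must contain a Hugging Face style *.index.json file; a regular file is
    treated as an unsharded checkpoint.
    """
    path = Path(checkpoint)
    if path.is_dir():
        # e.g. pytorch_model.bin.index.json or model.safetensors.index.json
        index_files = sorted(path.glob("*.index.json"))
        if not index_files:
            raise ValueError(f"{path} is a directory but has no *.index.json file")
        return "sharded", index_files[0]
    if path.is_file():
        return "unsharded", path
    raise FileNotFoundError(f"no checkpoint found at {path}")
```

Branching on directory versus file mirrors the path rules of save_model: a sharded checkpoint is always a directory, an unsharded one is always a single file.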
If you want to load a pretrained model from Hugging Face but the model is too large to be loaded directly through from_pretrained on a single device, the recommended approach is to download the pretrained weights to a local directory and, after boosting the model, load them from that directory with booster.load_model. The model should also be initialized under the lazy initialization context to avoid running out of memory (OOM). Here is example pseudocode:
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device
from huggingface_hub import snapshot_download
from transformers import LlamaForCausalLM
...
# Initialize the model under the lazy init context so weights are not materialized yet
init_ctx = LazyInitContext(default_device=get_current_device())
with init_ctx:
    model = LlamaForCausalLM(config)
...
# Wrap the model through Booster.boost
model, optimizer, _, _, _ = booster.boost(model, optimizer)
# Download the Hugging Face pretrained model to a local directory
model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")
# Load the downloaded weights into the boosted model
booster.load_model(model, model_dir)
...
Optimizer Checkpoint
colossalai.booster.Booster.save_optimizer(optimizer, checkpoint, shard=False, gather_dtensor=True, prefix=None, size_per_shard=1024)
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path: a file path if shard=False, otherwise a directory path.
- shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
- gather_dtensor (bool) -- Whether to gather the distributed tensor to the first device. Defaults to True.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. Only used when shard=True. Defaults to 1024.
Saves the optimizer state to the checkpoint. The optimizer must be boosted by colossalai.booster.Booster before saving.
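An optimizer checkpoint is shaped differently from a model checkpoint: torch.optim.Optimizer.state_dict() returns a "state" dict keyed by parameter index plus a "param_groups" list holding hyperparameters such as the learning rate. The sketch below shows one way such a dict could be split for a sharded save. split_optimizer_state is hypothetical, it counts parameters per shard instead of measuring size_per_shard in MB, and it does not reproduce the files ColossalAI actually writes.

```python
def split_optimizer_state(state_dict, max_params_per_shard):
    """Split a PyTorch-style optimizer state_dict into param_groups plus state shards.

    Hypothetical helper. state_dict has the layout returned by
    torch.optim.Optimizer.state_dict():
    {"state": {param_id: {...}}, "param_groups": [{"params": [...], ...}, ...]}.
    """
    if max_params_per_shard < 1:
        raise ValueError("max_params_per_shard must be at least 1")
    # Hyperparameter groups are small and shared, so they stay in one piece.
    param_groups = state_dict["param_groups"]
    shards, current = [], {}
    for param_id, per_param_state in state_dict["state"].items():
        # Close the current shard once it holds max_params_per_shard entries.
        if len(current) == max_params_per_shard:
            shards.append(current)
            current = {}
        current[param_id] = per_param_state
    if current:
        shards.append(current)
    return param_groups, shards
```

Keeping param_groups separate from the per-parameter state means a loader can rebuild the hyperparameters first and then fill in each parameter's state shard by shard.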
colossalai.booster.Booster.load_optimizer(optimizer, checkpoint, prefix=None, size_per_shard=1024)
- optimizer (Optimizer) -- An optimizer boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded; otherwise, it should be a file path.
- prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
- size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. Only relevant for sharded checkpoints. Defaults to 1024.
The optimizer must be boosted by colossalai.booster.Booster before loading.
LR Scheduler Checkpoint
colossalai.booster.Booster.save_lr_scheduler(lr_scheduler, checkpoint)
- lr_scheduler (LRScheduler) -- An LR scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
The LR scheduler must be boosted by colossalai.booster.Booster before saving. The checkpoint argument is the local path of the checkpoint file.
colossalai.booster.Booster.load_lr_scheduler(lr_scheduler, checkpoint)
- lr_scheduler (LRScheduler) -- An LR scheduler boosted by Booster.
- checkpoint (str) -- Path to the checkpoint. It must be a local file path.
The LR scheduler must be boosted by colossalai.booster.Booster before loading. The checkpoint argument is the local path of the checkpoint file.
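When resuming training you usually save the model, optimizer and LR scheduler together. The sketch below encodes the path rules from this tutorial in one place, assuming one step-<N> directory per checkpoint: model and optimizer checkpoints are directories when shard=True and files otherwise, while the LR scheduler checkpoint is always a single file. checkpoint_paths is a hypothetical helper, and the directory names and .pt suffix are arbitrary choices.

```python
from pathlib import Path


def checkpoint_paths(root, step, shard):
    """Build the model, optimizer and LR scheduler paths for one checkpoint.

    Hypothetical helper. Model and optimizer paths are directories when
    shard=True and files otherwise; the LR scheduler path is always a file.
    """
    base = Path(root) / f"step-{step}"
    # Sharded checkpoints are folders, so they get no file suffix.
    suffix = "" if shard else ".pt"
    return {
        "model": base / f"model{suffix}",
        "optimizer": base / f"optimizer{suffix}",
        "lr_scheduler": base / "lr_scheduler.pt",
    }
```

The resulting paths would then be passed as strings to booster.save_model(model, str(paths["model"]), shard=True), booster.save_optimizer(optimizer, str(paths["optimizer"]), shard=True) and booster.save_lr_scheduler(lr_scheduler, str(paths["lr_scheduler"])), and to the matching load_model, load_optimizer and load_lr_scheduler calls when resuming. Depending on the checkpoint IO in use, you may need to create the parent directory first, for example with paths["lr_scheduler"].parent.mkdir(parents=True, exist_ok=True).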
Checkpoint design
More details about checkpoint design can be found in our discussion A Unified Checkpoint System Design.