Booster Checkpoint

Author: Hongxin Liu

Prerequisite:
  • Booster API

Introduction

We've introduced the Booster API in the previous tutorial. In this tutorial, we will show how to save and load checkpoints with Booster.

Model Checkpoint

function colossalai.booster.Booster.save_model(model: typing.Union[torch.nn.modules.module.Module, colossalai.interface.model.ModelWrapper], checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False)
Parameters
  • model (nn.Module or ModelWrapper) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If true, the checkpoint will be a folder in the same format as a Hugging Face transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
  • gather_dtensor (bool, optional) -- Whether to gather the distributed tensor to the first device. Defaults to True.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
  • use_safetensors (bool, optional) -- Whether to use safetensors. If set to True, the checkpoint will be saved in the safetensors format. Defaults to False.
Description
Save model to checkpoint.

Model must be boosted by colossalai.booster.Booster before saving. checkpoint is the path to the saved checkpoint. It can be a file path if shard=False; otherwise, it should be a directory path. If shard=True, the checkpoint will be saved in a sharded way, which is useful when the checkpoint is too large to fit in a single file. Our sharded checkpoint format is compatible with huggingface/transformers, so you can use the Hugging Face from_pretrained method to load a model from our sharded checkpoint.
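
For example, here is a minimal sketch of saving a boosted model (assuming model has already been boosted by booster.boost; the checkpoint paths are placeholders):

# Save as a single checkpoint file (shard=False is the default)
booster.save_model(model, "./checkpoint/model.pt")

# Save as a sharded, Hugging Face-compatible folder with shard files of at most 1024 MB each
booster.save_model(model, "./checkpoint/model_shards", shard=True, size_per_shard=1024)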

function colossalai.booster.Booster.load_model(model: typing.Union[torch.nn.modules.module.Module, colossalai.interface.model.ModelWrapper], checkpoint: str, strict: bool = True)
Parameters
  • model (nn.Module or ModelWrapper) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
  • strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Defaults to True.
Description
Load model from checkpoint.

Model must be boosted by colossalai.booster.Booster before loading. It will detect the checkpoint format automatically and load it in the corresponding way.

If you want to load a pretrained model from Hugging Face but the model is too large to be loaded directly through from_pretrained on a single device, a recommended approach is to download the pretrained weights to a local directory and use booster.load_model to load from that directory after boosting the model. The model should also be initialized under the lazy initialization context to avoid OOM. Here is example pseudocode:

from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device
from huggingface_hub import snapshot_download
...

# Initialize the model under the lazy init context to avoid OOM
init_ctx = LazyInitContext(default_device=get_current_device())
with init_ctx:
    model = LlamaForCausalLM(config)

...

# Wrap the model through Booster.boost
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# Download the Hugging Face pretrained weights to a local directory
model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")

# Load the weights into the boosted model with booster.load_model
booster.load_model(model, model_dir)
...

Optimizer Checkpoint

function colossalai.booster.Booster.save_optimizer(optimizer: Optimizer, checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
  • gather_dtensor (bool) -- Whether to gather the distributed tensor to the first device. Defaults to True.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of checkpoint shard file in MB. This is useful only when shard=True. Defaults to 1024.
Description
Save optimizer to checkpoint.

Optimizer must be boosted by colossalai.booster.Booster before saving.
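
For example, a minimal sketch of saving a boosted optimizer (assuming optimizer has been boosted by booster.boost; the path is a placeholder):

# Save the optimizer states as a sharded checkpoint under a directory
booster.save_optimizer(optimizer, "./checkpoint/optimizer", shard=True, size_per_shard=1024)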

function colossalai.booster.Booster.load_optimizer(optimizer: Optimizer, checkpoint: str)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
Description
Load optimizer from checkpoint.

Optimizer must be boosted by colossalai.booster.Booster before loading.
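
For example, a minimal sketch of restoring the optimizer states saved above (the path is a placeholder):

# The sharded format is detected automatically from the given directory
booster.load_optimizer(optimizer, "./checkpoint/optimizer")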

LR Scheduler Checkpoint

function colossalai.booster.Booster.save_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Save lr scheduler to checkpoint.

LR scheduler must be boosted by colossalai.booster.Booster before saving. checkpoint is the local path to the checkpoint file.

function colossalai.booster.Booster.load_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- A lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Load lr scheduler from checkpoint.

LR scheduler must be boosted by colossalai.booster.Booster before loading. checkpoint is the local path to the checkpoint file.
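
For example, a minimal sketch of saving and restoring an lr scheduler (assuming lr_scheduler has been boosted by booster.boost; the path is a placeholder):

# Save the lr scheduler state to a single file
booster.save_lr_scheduler(lr_scheduler, "./checkpoint/lr_scheduler.pt")

# Later, restore it into a boosted lr scheduler of the same type
booster.load_lr_scheduler(lr_scheduler, "./checkpoint/lr_scheduler.pt")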

Checkpoint design

More details about checkpoint design can be found in our discussion A Unified Checkpoint System Design.