Zero Redundancy Optimizer and Zero Offload
Author: Zhujie, Shenggui Li, Hongxin Liu
Prerequisite:
Example Code
Related Paper
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- ZeRO-Offload: Democratizing Billion-Scale Model Training
- ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning
- PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
Introduction
The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) instead of replicating them. By doing so, memory efficiency is boosted drastically compared to classic data parallelism, while the computational granularity and communication efficiency are retained.
- Shard Optimizer States: The optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its own partition.
- Shard Gradients: After reduction inside the data-parallel process group, gradient tensors are also partitioned such that each process only stores the gradients corresponding to its partition of the optimizer states. Note that Colossal-AI converts gradients to fp32 format for the parameter update.
- Shard Parameters: The 16-bit model parameters are partitioned across the processes of a data-parallel group.
- Gemini: A dynamic heterogeneous memory space manager for parameters, gradients and optimizer states.
When we shard the parameters, gradients and optimizer states, and set the tensor placement policy to "cpu", the training process can be illustrated with the following three figures.



For more details about Gemini, click here.
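To make the savings concrete, here is a rough back-of-the-envelope sketch (illustrative only, not part of the library) of the per-GPU memory consumed by the model states under mixed-precision Adam, using the 2Ψ + 2Ψ + 12Ψ bytes-per-parameter accounting from the ZeRO paper:

```python
# Per-GPU memory (GB) for model states with mixed-precision Adam, following the
# ZeRO paper's accounting: 2 bytes/param (fp16 params) + 2 bytes/param (fp16 grads)
# + 12 bytes/param (fp32 param copy, first and second moments). Illustrative only.
def model_state_memory_gb(num_params, num_gpus, shard_os=False, shard_grad=False, shard_param=False):
    params = 2 * num_params / (num_gpus if shard_param else 1)
    grads = 2 * num_params / (num_gpus if shard_grad else 1)
    optim_states = 12 * num_params / (num_gpus if shard_os else 1)
    return (params + grads + optim_states) / 1024**3

psi = 1.5e9  # roughly a GPT-2 XL-sized model
print(model_state_memory_gb(psi, 64))                                   # ~22.4 GB: classic data parallelism
print(model_state_memory_gb(psi, 64, shard_os=True))                    # ~5.9 GB: shard optimizer states only
print(model_state_memory_gb(psi, 64, shard_os=True, shard_grad=True,
                            shard_param=True))                          # ~0.35 GB: shard all three states
```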
Usage
We provide two levels of API to use ZeRO.
- Low-level API: Use `ShardedModel` and `ShardedOptimizer` directly, and write your own training loop from scratch.
- High-level API: Use `Engine` and configure ZeRO in the configuration file. You can use `Trainer` or write your own training loop.
We provide some shard strategies to manage the process of sharding your model:
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
`TensorShardStrategy` is a naive implementation that shards each tensor evenly over all ranks. `BucketTensorShardStrategy` flattens the tensors belonging to an operator, e.g. `nn.Linear`, and then shards them evenly over all ranks. It is especially useful when an operator contains a `bias`, since we cannot utilize network bandwidth well if we only gather a `bias` tensor (a `bias` tensor is usually small).
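For example, switching to the bucket strategy only requires swapping the strategy object (both classes are importable as shown above):

```python
from colossalai.zero.shard_utils import BucketTensorShardStrategy

# flatten the tensors of each operator into one bucket before sharding them evenly over all ranks
shard_strategy = BucketTensorShardStrategy()
```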
⚠️ You have to initialize your model with `colossalai.zero.init_ctx.ZeroInitContext`.
Here is a simple example:
shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = torch.nn.Linear(2, 2)
You can see the exact usage of `ZeroInitContext` in the API Reference.
If you use the high-level API, you must configure `shard_strategy` in the config file.
Next, we will give you a configuration template to help you configure ZeRO when using the high-level API. Then, we will give an example of using the low-level API.
We now provide `colossalai.nn.optimizer.HybridAdam`, which is faster than `torch.optim.Adam`. For more details, see the API Reference.
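As a minimal sketch, it is a drop-in replacement for the optimizer construction (mirroring the optimizer setup in the low-level example later in this tutorial):

```python
from colossalai.nn.optimizer import HybridAdam

# used exactly like torch.optim.Adam, but faster
optimizer = HybridAdam(model.parameters(), lr=1e-3)
```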
Configure ZeRO with high-level API
You can use `Engine` and configure ZeRO in the configuration file.
Here is a configuration template:
from colossalai.zero.shard_utils import TensorShardStrategy
zero = dict(
    model_config=dict(
        shard_strategy=TensorShardStrategy(),
        reduce_scatter_bucket_size_mb=25,
        fp32_reduce_scatter=False,
        tensor_placement_policy="cuda",
        gradient_predivide_factor=1.0,
        reuse_fp16_shard=False
    ),
    optimizer_config=dict(
        gpu_margin_mem_ratio=0.8,
        initial_scale=2**5,
        min_scale=1,
        growth_factor=2,
        backoff_factor=0.5,
        growth_interval=1000,
        hysteresis=2,
        max_scale=2**32
    )
)
`model_config` and `optimizer_config` are keyword arguments of `ShardedModelV2` and `ShardedOptimizerV2` respectively. For more details of these arguments, see the ShardedModelV2 API Reference and the ShardedOptimizerV2 API Reference.
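In other words, these config entries map one-to-one onto constructor keyword arguments. A sketch of that correspondence, reusing `shard_strategy`, `ShardedModelV2`, `ShardedOptimizerV2` and `HybridAdam` from the low-level example later in this tutorial:

```python
# model_config entries become keyword arguments of ShardedModelV2
model = ShardedModelV2(model,                   # model built inside ZeroInitContext
                       shard_strategy,          # the same strategy instance passed to ZeroInitContext
                       reduce_scatter_bucket_size_mb=25,
                       fp32_reduce_scatter=False,
                       tensor_placement_policy="cuda",
                       gradient_predivide_factor=1.0,
                       reuse_fp16_shard=False)

# optimizer_config entries become keyword arguments of ShardedOptimizerV2
optimizer = ShardedOptimizerV2(model,
                               HybridAdam(model.parameters(), lr=1e-3),
                               gpu_margin_mem_ratio=0.8,
                               initial_scale=2**5)
```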
⚠️ If you use gradient accumulation, make sure `reuse_fp16_shard` is `False`.
⚠️ If you set `tensor_placement_policy` to `"auto"`, make sure no other processes use CUDA during your training.
You can initialize your model in this way:
import torch
import colossalai
from colossalai.core import global_context as gpc
from colossalai.zero.init_ctx import ZeroInitContext

with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=gpc.config.zero.model_config.shard_strategy,
                     shard_param=True):
    model = torch.nn.Linear(2, 2)
Then you can use `Engine` as usual.
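For reference, a minimal sketch of the usual `Engine` flow, where `optimizer`, `criterion` and `train_dataloader` stand in for your own objects:

```python
# colossalai.initialize builds an Engine that applies the ZeRO settings from the loaded config
engine, train_dataloader, _, _ = colossalai.initialize(model=model,
                                                       optimizer=optimizer,
                                                       criterion=criterion,
                                                       train_dataloader=train_dataloader)

engine.train()
for data, label in train_dataloader:
    engine.zero_grad()
    output = engine(data)
    loss = engine.criterion(output, label)
    engine.backward(loss)
    engine.step()
```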
The complete example of training GPT with the high-level API can be found in the GPT example.
Train GPT with low-level API
In this example, we use Hugging Face Transformers. You have to install `transformers` before running this example. We will take GPT2 Medium as an example here.
This example is intended to show you how to use ZeRO. For simplicity, we just use randomly generated data here.
First, we have to import the essential libraries:
import colossalai
import torch
import torch.nn as nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from transformers import GPT2Config, GPT2LMHeadModel
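The training loop below also logs memory usage via a `get_mem_info` helper that is not among the imports above; a minimal sketch (assuming `psutil` is installed) could look like this:

```python
import psutil

def get_cpu_mem():
    # resident CPU memory of the current process, in MB
    return psutil.Process().memory_info().rss / 1024**2

def get_gpu_mem():
    # CUDA memory currently allocated by tensors, in MB
    return torch.cuda.memory_allocated() / 1024**2

def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
```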
Then we simply wrap Hugging Face Transformers:
class GPTLMModel(nn.Module):
    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
                                                n_head=num_attention_heads, n_positions=max_seq_len,
                                                n_ctx=max_seq_len, vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
Define our loss function:
class GPTLMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
As we pre-train GPT in this example, we just use a simple language model loss.
Write a function to get random inputs:
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
Finally, we can define our training loop:
def main():
    BATCH_SIZE = 8
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257
    NUM_STEPS = 10

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})
    logger = get_dist_logger()
    logger.info(get_mem_info(), ranks=[0])

    # build GPT model
    shard_strategy = TensorShardStrategy()
    with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True):
        model = gpt2_medium(checkpoint=True)
    # Set tensor_placement_policy='cpu', which will offload params, grads and optimizer states
    model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5)
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        logger.info(get_mem_info(prefix=f'Forward [{n+1}/{NUM_STEPS}] '), ranks=[0])
        optimizer.backward(loss)
        logger.info(get_mem_info(prefix=f'Backward [{n+1}/{NUM_STEPS}] '), ranks=[0])
        optimizer.step()
        logger.info(get_mem_info(prefix=f'Optimizer step [{n+1}/{NUM_STEPS}] '), ranks=[0])


if __name__ == '__main__':
    main()
The complete example can be found in the ZeRO example.