
Zero Redundancy Optimizer with chunk-based memory management

Author: Hongxin Liu, Jiarui Fang, Zijian Ye




The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning three model states (optimizer states, gradients, and parameters) instead of replicating them. This makes memory use far more efficient than in classic data parallelism, while the computational granularity and communication efficiency are retained.

  1. Shard Optimizer States: The optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its own partition.

  2. Shard Gradients: After reduction within the data-parallel process group, gradient tensors are also partitioned, so that each process stores only the gradients corresponding to its partition of the optimizer states. Note that Colossal-AI converts gradients to fp32 before they are used in the parameter update.

  3. Shard Parameters: The 16-bit model parameters are partitioned across the processes of a data-parallel group.

  4. Gemini: A dynamic heterogeneous memory space manager for parameters, gradients, and optimizer states.

In addition, this article introduces the Zero Redundancy Optimizer with chunk-based memory management.

When using ZeRO, we distribute the model by sharding its parameters. The advantage of this method is that memory is load-balanced across nodes. However, this approach has two significant disadvantages. First, each communication needs a temporary memory buffer that is allocated and released afterwards, which leads to memory fragmentation. Second, communicating at the granularity of individual tensors leaves network bandwidth underutilized: generally, the longer the transmitted message, the higher the bandwidth utilization.
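Why message length matters can be illustrated with the common latency-bandwidth ("alpha-beta") cost model, which is a generic textbook model and not something specific to Colossal-AI: each message pays a fixed startup cost alpha, plus n/beta seconds for n bytes at peak bandwidth beta. The sketch below assumes hypothetical values for alpha and beta, so only the trend is meaningful, not the exact percentages.

```python
def transfer_time(n_bytes: float, alpha: float, beta: float) -> float:
    """Alpha-beta model: fixed per-message latency plus size divided by peak bandwidth."""
    return alpha + n_bytes / beta


def effective_bandwidth(n_bytes: float, alpha: float, beta: float) -> float:
    """Bytes actually delivered per second when sending one message of n_bytes."""
    return n_bytes / transfer_time(n_bytes, alpha, beta)


# Hypothetical link parameters: 10 us latency per message, 100 GB/s peak bandwidth.
ALPHA = 10e-6
BETA = 100e9

for size in (4 * 2**10, 2**20, 64 * 2**20):  # 4 KiB, 1 MiB, 64 MiB
    utilization = effective_bandwidth(size, ALPHA, BETA) / BETA
    print(f"{size:>10d} bytes -> {utilization:.1%} of peak bandwidth")

# The same 1 MiB of data: 256 separate 4 KiB messages vs. one merged message.
separate = 256 * transfer_time(4 * 2**10, ALPHA, BETA)
merged = transfer_time(256 * 4 * 2**10, ALPHA, BETA)
print(f"separate: {separate * 1e6:.1f} us, merged: {merged * 1e6:.1f} us")
```

Because the latency term is paid once per message, small messages spend most of their time on fixed overhead, and merging them amortizes that cost.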

The Chunk mechanism introduced in ColossalAI v0.1.8 improves the efficiency of ZeRO. We store a contiguous set of parameters, in initialization order, in a Chunk (a chunk is a contiguous memory space), and all Chunks have the same size. Organizing memory in chunks makes efficient use of PCI-e and GPU-GPU bandwidth, reduces the number of communications, and avoids potential memory fragmentation.
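The packing rule can be sketched in plain Python. This is an illustrative model, not Colossal-AI's chunk manager: parameters are visited in initialization order and appended back to back to the current fixed-size chunk, and a new chunk is opened when the next parameter does not fit. The sketch assumes that a parameter is never split across chunks and measures sizes in elements; the parameter names and sizes are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Chunk:
    capacity: int  # every chunk has the same capacity (in elements)
    used: int = 0  # elements already occupied
    # (parameter name, offset inside the chunk, numel) for each packed parameter
    slots: List[Tuple[str, int, int]] = field(default_factory=list)

    def fits(self, numel: int) -> bool:
        return self.used + numel <= self.capacity

    def append(self, name: str, numel: int) -> None:
        # parameters are laid out back to back, so the chunk stays contiguous
        self.slots.append((name, self.used, numel))
        self.used += numel


def pack_params(params: List[Tuple[str, int]], chunk_size: int) -> List[Chunk]:
    """Pack (name, numel) pairs, in initialization order, into equal-sized chunks."""
    chunks: List[Chunk] = []
    for name, numel in params:
        if numel > chunk_size:
            raise ValueError(f"{name} ({numel} elements) is larger than the chunk size")
        if not chunks or not chunks[-1].fits(numel):
            chunks.append(Chunk(capacity=chunk_size))
        chunks[-1].append(name, numel)
    return chunks


# Hypothetical parameters, listed in initialization order.
params = [("embed", 600), ("attn.qkv", 300), ("attn.proj", 100), ("mlp.fc1", 400), ("mlp.fc2", 400)]
chunks = pack_params(params, chunk_size=1000)
for i, c in enumerate(chunks):
    print(i, c.slots, "free:", c.capacity - c.used)
```

Here the first chunk is filled exactly, while the second keeps 200 free elements; that unused tail is the kind of waste a chunk-size search tries to minimize.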

Before v0.1.8, ZeRO had a high communication cost for parameters. If a parameter was used multiple times in several consecutive operators, it was communicated repeatedly, which severely hurt efficiency. This situation is very common with gradient checkpointing, where the forward pass is recomputed during the backward pass, so the same parameters are needed again.

Taking GPT as an example, checkpointing is applied to each GPT block, and each block contains a Self-Attention layer and an MLP layer. During the backward pass, the forward computations of the Self-Attention and MLP layers are recomputed in turn, followed by the backward computations of the MLP and Self-Attention layers.
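A toy count makes the difference concrete. The sketch below is a simplified model, not Colossal-AI's actual scheduling: it replays the parameter accesses of one checkpointed GPT block during the backward pass (recompute Self-Attention, recompute MLP, backward MLP, backward Self-Attention). A communication unit is gathered when an operator needs it and it is not resident, and it is released after the operator unless the next operator also uses it. With per-tensor communication every weight is its own unit; in the chunked case we assume all four hypothetical weights share one chunk.

```python
from typing import Dict, List


def count_gathers(ops: List[List[str]], unit_of: Dict[str, str]) -> int:
    """Count gather operations for a sequence of operators.

    ops: for each operator, the parameter names it reads.
    unit_of: maps each parameter to its communication unit (itself or its chunk).
    """
    resident = set()
    gathers = 0
    for i, op in enumerate(ops):
        needed = {unit_of[p] for p in op}
        missing = needed - resident
        gathers += len(missing)  # one gather per unit that is not resident yet
        resident |= missing
        nxt = {unit_of[p] for p in ops[i + 1]} if i + 1 < len(ops) else set()
        resident &= nxt  # release units the next operator does not use
    return gathers


attn = ["attn.qkv", "attn.proj"]
mlp = ["mlp.fc1", "mlp.fc2"]
# Backward pass of one checkpointed block: recompute attn, mlp; then backward mlp, attn.
backward_ops = [attn, mlp, mlp, attn]

per_tensor = {p: p for p in attn + mlp}
per_chunk = {p: "chunk0" for p in attn + mlp}

print("per-tensor gathers:", count_gathers(backward_ops, per_tensor))
print("per-chunk gathers: ", count_gathers(backward_ops, per_chunk))
```

Under these assumptions the per-tensor scheme issues 6 small gathers, while the chunked scheme issues a single larger one that stays resident across the consecutive operators.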

Moreover, when communication and memory movement operate on small tensors, the bandwidth of NVLink and PCI-e cannot be fully utilized, and every communication or memory movement pays the overhead of a kernel launch. With chunks, many small tensor transfers are merged into one large transfer, which both improves bandwidth utilization and reduces kernel-launch overhead.

We also provide a lightweight chunk search mechanism to help users automatically find the chunk size with the smallest memory fragmentation.
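A search in the same spirit can be sketched as a brute-force scan: try candidate chunk sizes on a fixed grid, pack the parameters greedily in initialization order, and keep the size that wastes the least space. This is a hypothetical illustration; the candidate grid (min_size, search_range, interval) and the waste metric are assumptions and not the exact algorithm Colossal-AI uses.

```python
from typing import Iterable, List, Tuple


def wasted_space(numels: List[int], chunk_size: int) -> int:
    """Unused elements left when packing numels, in order, into chunks of chunk_size."""
    waste, used = 0, 0
    for n in numels:
        if used + n > chunk_size:  # current chunk cannot take n: close it, open a new one
            waste += chunk_size - used
            used = 0
        used += n
    return waste + (chunk_size - used)  # the free tail of the last chunk also counts


def search_chunk_size(numels: List[int], min_size: int, search_range: int,
                      interval: int) -> Tuple[int, int]:
    """Return (chunk_size, waste) with the least waste among the candidate sizes."""
    # a chunk must hold the largest parameter, since parameters are not split here
    lower = max(min_size, max(numels))
    candidates: Iterable[int] = range(lower, lower + search_range + 1, interval)
    return min(((size, wasted_space(numels, size)) for size in candidates),
               key=lambda pair: pair[1])


sizes = [600, 300, 100, 400, 400]  # hypothetical parameter sizes, in init order
best_size, waste = search_chunk_size(sizes, min_size=800, search_range=400, interval=100)
print(best_size, waste)  # prints: 900 0
```

A coarser interval means fewer candidates to evaluate, at the risk of skipping a size with less waste.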



We use GeminiDDP to apply ZeRO with chunk-based memory management. GeminiDDP is our new torch.nn.Module wrapper that combines ZeRO-DP and Gemini: ZeRO provides the parallelism, and Gemini handles memory management.

Gemini supports LazyInitContext, which saves memory when initializing large models on multiple GPUs.

If your model has N billion parameters and your GPU memory is M GB, we recommend you use LazyInitContext when 4N >= M. Otherwise, LazyInitContext is optional.
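The rule of thumb is simple arithmetic. One way to read it (our interpretation, not stated above) is that N billion parameters stored in fp32 take about 4N GB, so 4N >= M means the full-precision weights alone would not fit in one GPU's memory. The helper name below is hypothetical.

```python
def recommend_lazy_init(n_billion_params: float, gpu_mem_gb: float) -> bool:
    """Return True when LazyInitContext is recommended by the 4N >= M rule."""
    # fp32 weights take 4 bytes per parameter: N * 1e9 * 4 bytes is about 4N GB
    return 4 * n_billion_params >= gpu_mem_gb


print(recommend_lazy_init(20, 80))    # True: 20B parameters on an 80 GB GPU
print(recommend_lazy_init(0.35, 80))  # False: roughly GPT-2 Medium size
```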

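```python
import torch
from colossalai.lazy import LazyInitContext

# parameters are initialized lazily to reduce peak memory during model construction
with LazyInitContext(default_device=torch.device('cuda')):
    model = gpt2_medium(checkpoint=True)
```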

We provide the user-friendly Booster API and recommend using it. If you still want to use the low-level API, read the rest of this section.

Wrap the model with GeminiDDP.

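```python
# import path as in ColossalAI 0.3.x; it may differ in other releases
from colossalai.zero import GeminiDDP

model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
```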

hidden_dim is the hidden dimension of the DNN. Providing it speeds up the chunk-size search; if you do not know it before training, that is fine, and a default value of 1024 is used. min_chunk_size_m is a floating-point number equal to the minimum chunk size divided by 2^20 (e.g., if min_chunk_size_m=2.5, the minimum chunk size is 2.5*(2^20)). If the aggregate size of all parameters is still smaller than the minimum chunk size, all parameters are compacted into one small chunk.
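The conversion from min_chunk_size_m to an absolute size, and the single-chunk fallback for small models, can be written out directly. This is a sketch of the rule stated above with hypothetical helper names; it does not call Colossal-AI.

```python
def min_chunk_size(min_chunk_size_m: float) -> int:
    """Minimum chunk size = min_chunk_size_m * 2**20."""
    return int(min_chunk_size_m * 2**20)


def fits_in_single_chunk(param_numels, min_chunk_size_m: float) -> bool:
    """True if all parameters together are smaller than the minimum chunk size,
    in which case they are compacted into one small chunk."""
    return sum(param_numels) < min_chunk_size(min_chunk_size_m)


print(min_chunk_size(2.5))                             # 2621440
print(fits_in_single_chunk([1024 * 1024, 4096], 2.5))  # True
```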

Initialize the optimizer.

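```python
# import path as in ColossalAI 0.3.x; it may differ in other releases
from colossalai.zero import GeminiAdamOptimizer

optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
```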


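Train the model:

```python
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
optimizer.backward(loss)  # not loss.backward()
optimizer.step()
```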

⚠️ Note: Please do not call loss.backward(); use optimizer.backward(loss) instead.

Train GPT

In this example, we use Hugging Face Transformers. You have to install transformers before running this example. We will take GPT2 Medium as an example here.

For simplicity, we just use randomly generated data here.

First, we import GPT2LMHeadModel from Hugging Face Transformers to define our model. Users do not need to define or modify the model themselves, which makes it more convenient to use.

Define a GPT model:

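```python
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel


class GPTLMModel(nn.Module):

    def __init__(self, hidden_size, num_layers, num_attention_heads,
                 max_seq_len=1024, vocab_size=50257, checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size, n_layer=num_layers, n_head=num_attention_heads,
                       n_positions=max_seq_len, vocab_size=vocab_size))
        if checkpoint:
            # recompute activations during the backward pass to save memory
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # the KV cache is not needed when checkpointing; [0] selects the logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
```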

Define our loss function:

```python
import torch.nn as nn


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # shift so that the logits at position t are scored against the token at t + 1
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```
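GPTLMLoss drops the last position's logits and the first label so that each prediction is scored against the next token. The index bookkeeping can be checked without PyTorch; the helper name and token ids below are made up for illustration.

```python
def next_token_pairs(token_ids):
    """Pair each kept logit position with the label it is scored against,
    mirroring logits[..., :-1, :] and labels[..., 1:] in GPTLMLoss."""
    positions = list(range(len(token_ids)))[:-1]  # drop the last position's logits
    labels = token_ids[1:]                        # drop the first token as a label
    return list(zip(positions, labels))


print(next_token_pairs([15, 7, 42, 9]))  # [(0, 7), (1, 42), (2, 9)]
```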

Write a function to get random inputs:

```python
import torch


def get_data(batch_size, seq_len, vocab_size):
    # random token ids on the current GPU, with an all-ones attention mask
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
```

Finally, we define a model that uses Gemini + ZeRO DDP, along with our training loop. As we pre-train GPT in this example, we simply use a language-model loss:

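```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

# example values; tune them for your hardware
BATCH_SIZE = 8
NUM_STEPS = 10
SEQ_LEN = 1024
VOCAB_SIZE = 50257


def main():
    # initialize the distributed environment (launch the script with torchrun);
    # older ColossalAI releases require a config argument, e.g. launch_from_torch(config={})
    colossalai.launch_from_torch()

    # build criterion
    criterion = GPTLMLoss()

    # build GPT model with lazy initialization to save memory
    with LazyInitContext(default_device=torch.device('cuda')):
        model = gpt2_medium(checkpoint=True)

    # the optimizer can only be created once the model exists
    optimizer = HybridAdam(model.parameters(), lr=0.001)

    # Gemini + ZeRO DP
    plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    model.train()
    for step in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        # use booster.backward instead of loss.backward()
        booster.backward(loss, optimizer)
        optimizer.step()


if __name__ == '__main__':
    main()
```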


⚠️ Note: If you want to use the Gemini module, please do not use the gradient accumulation mentioned earlier. The complete example can be found in Train GPT with Colossal-AI.