Version: v0.1.6

Step By Step: Pretrain GPT-2 On Single GPU with ZeRO

Author: Yuxuan Lou

Example Code

Related Papers


Generative Pre-trained Transformer-2 (GPT-2) is an autoregressive language model created by OpenAI. It uses deep learning to produce human-like text. As the quality of the text generated is very high, GPT-2 is well known and widely used. However, it is hard for researchers and users to pretrain GPT-2 from scratch due to its huge model scale.

Colossal-AI provides a good solution to this: the Zero Redundancy Optimizer (ZeRO). ZeRO removes the memory redundancies across data-parallel processes by partitioning three model states (optimizer states, gradients, and parameters) instead of replicating them. This greatly improves memory efficiency compared to classic data parallelism, while computational granularity and communication efficiency are retained. ZeRO also enables CPU offloading: the optimizer states are moved from GPU to CPU memory to reduce GPU memory usage.
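To make the memory argument concrete, here is a rough back-of-the-envelope sketch. It assumes mixed-precision Adam with the usual accounting (2 bytes per parameter for fp16 weights, 2 bytes for fp16 gradients, and 12 bytes for fp32 optimizer states). The function name and the 124M parameter count used for GPT-2 small are illustrative assumptions, and activations and temporary buffers are ignored.

```python
def model_state_bytes_per_gpu(num_params, num_gpus=1, offload_optimizer=False):
    """Estimate GPU bytes used by model states when all three are partitioned."""
    params = 2 * num_params   # fp16 parameters
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 master weights, momentum, and variance
    if offload_optimizer:
        optim = 0             # assumption: optimizer states live in CPU memory
    return (params + grads + optim) / num_gpus


n = 124_000_000  # assumed parameter count for GPT-2 small
baseline = model_state_bytes_per_gpu(n)
offloaded = model_state_bytes_per_gpu(n, offload_optimizer=True)
sharded = model_state_bytes_per_gpu(n, num_gpus=4)

print(f'single GPU, no offload : {baseline / 1e9:.2f} GB')
print(f'single GPU, CPU offload: {offloaded / 1e9:.2f} GB')
print(f'4 GPUs, fully sharded  : {sharded / 1e9:.2f} GB')
```

Under these assumptions, partitioning alone saves nothing on a single GPU, since there is only one data-parallel process; the saving in the single-GPU setting comes from offloading the optimizer states.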

Currently, Colossal-AI provides two levels of API for using ZeRO.

  • Low-level API: Use ShardedModel and ShardedOptimizer directly, and write your own training loop from scratch.
  • High-level API: Use Engine and configure ZeRO in the configuration file. You can use Trainer or write your own training loop.

For more details, you can check here.

In this step-by-step tutorial, we will show you how to build a ZeRO GPT-2 model and pretrain it on a single GPU.

Table of Contents

In this tutorial we will cover:

  1. Colossal-AI installation
  2. Preparation of Webtext data for GPT-2 training
  3. Steps to apply ZeRO to training GPT-2

Colossal-AI Installation

You can install the Colossal-AI package and its dependencies from PyPI.

pip install colossalai

Define your configuration file (/gpt2_configs/​

Add ZeRO dict in the configuration file, which contains CPU offload and shard strategy settings.

from model_zoo.gpt.gpt import gpt2_small
from colossalai.nn.optimizer import CPUAdam
from import TensorShardStrategy

zero = dict(

Other configs:

SEQ_LEN = 1024

optimizer = dict(

model = dict(

Build GPT-2 model

In /model, we provide Colossal-AI based GPT models which can be adapted to different parallelism and ZeRO settings. For more details, you can check here.

Prepare data (Webtext dataset)

We use the publicly available OpenWebText library by jcpeterson and eukaryote31 to download URLs to different web pages. We then filter, clean, and deduplicate all downloaded content according to the procedure described in the following sections.

Install necessary packages

Note: LSH requires an earlier version of GCC. We have verified that version 9.3.0 works, but version 10.3.0 does not.

pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract cached-path
git clone
cd LSH
python install

If the installation fails, try replacing the cMinhash.cpp in LSH/lsh with ours, which is provided in tools/lsh/cMinhash.cpp.

Download Data

  1. Download the deduplicated URLs from jcpeterson.

  2. Unzip the zip file to get a folder named URLs, which contains many txt files of URLs.

  3. Remove blacklisted URLs.

    We appreciate Megatron-LM for making the data preprocessing code public. We have forked Megatron-LM and fixed some bugs. For your convenience, we have collated the needed files in tools/Megatron. Click here to check the source code of Megatron-LM.

    cd path/to/tools
    python Megatron/ <path/to/URLs> <path/to/clean_urls.txt>
  4. Download the content from the clean URLs and merge it into one loose JSON file, with one JSON object per line in the format {'text': text, 'url': unique_url}.

    We have forked and modified openwebtext as there are some bugs in it. For your convenience, we provide our modified version in tools/download.

    python download/ <path/to/clean_urls.txt> --n_procs 50 --output <path/to/raw.json>
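For reference, the loose JSON layout described above (one JSON object per line) can be written and read with the standard library alone. This is a minimal sketch: the helper names, the sample records, and the example.com URLs are illustrative assumptions and are not part of the provided tools.

```python
import json
import os
import tempfile


def write_loose_json(records, path):
    """Write one JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record) + '\n')


def read_loose_json(path):
    """Yield one parsed record per non-empty line."""
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


records = [
    {'text': 'First page body.', 'url': 'http://example.com/a'},
    {'text': 'Second page body.', 'url': 'http://example.com/b'},
]

with tempfile.TemporaryDirectory() as tmp:
    sample_path = os.path.join(tmp, 'raw.json')
    write_loose_json(records, sample_path)
    loaded = list(read_loose_json(sample_path))

print(len(loaded), 'records read back')
```

Reading line by line keeps memory usage flat, which matters once the merged file grows to many gigabytes.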

Prepare Data for GPT Training

  1. Run ftfy, detect English documents, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards.

    python Megatron/ <path/to/raw.json> <path/to/clean.json>

    Additional cleanup (e.g. remove documents less than 512 characters or dataset specific cleaning like stories, realnews datasets) can be done using More details can be found by running python --help.

  2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details can be found by running python --help.

    python Megatron/ --inputs <path/to/clean.json> url --output <path/to/process_stage_one.json>
  3. Based on the similarity measure defined inside the function is_similar (default: 0.9), group URLs that are similar. For each group, we keep only one URL and remove the rest.

    python Megatron/ <path/to/process_stage_one.json> <path/to/process_stage_two.json>
  4. Remove similar documents that were detected in the last step. The dedup.json is the data after deduplication.

    python Megatron/ <path/to/process_stage_two.json> <path/to/clean.json> <path/to/dedup.json>
  5. Shuffle the dataset.

    shuf <path/to/dedup.json> -o <path/to/train_data.json>
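The grouping idea in steps 2 to 4 can be illustrated with a small self-contained sketch. It is not the Megatron-LM code: it compares every pair of documents by brute force, which is exactly the cost LSH is meant to avoid, and the shingle size, helper names, and sample documents are assumptions chosen for illustration. Only the 0.9 threshold comes from the text above.

```python
from itertools import combinations


def shingles(text, k=5):
    """Return the set of k-character shingles of a whitespace-normalized document."""
    text = ' '.join(text.lower().split())
    return {text[i:i + k] for i in range(max(len(text) - k + 1, 1))}


def jaccard(a, b):
    """Jaccard similarity of two sets."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0


def group_similar(docs, threshold=0.9):
    """Group URLs whose documents have Jaccard similarity >= threshold (union-find)."""
    parent = {url: url for url in docs}

    def find(u):
        while parent[u] != u:
            parent[u] = parent[parent[u]]  # path halving
            u = parent[u]
        return u

    sets = {url: shingles(text) for url, text in docs.items()}
    for u, v in combinations(docs, 2):
        if jaccard(sets[u], sets[v]) >= threshold:
            parent[find(v)] = find(u)

    groups = {}
    for url in docs:
        groups.setdefault(find(url), []).append(url)
    return list(groups.values())


docs = {
    'url_a': 'the quick brown fox jumps over the lazy dog',
    'url_b': 'the quick brown fox jumps over the lazy dog',
    'url_c': 'an entirely different article about training language models',
}
groups = group_similar(docs)
keep = [group[0] for group in groups]  # keep one URL per group, drop the rest
print(groups, keep)
```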

Build Webtext dataset(./dataset/​

import json
import os

import torch
from colossalai.registry import DATASETS
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer


class WebtextDataset(Dataset):
    def __init__(self, path, seq_len=1024) -> None:
        root = os.path.dirname(path)
        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
        # Reuse the cached tokenized data if it was built with the same seq_len.
        if os.path.isfile(encoded_data_cache_path):
            seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
            if seq_len_ == seq_len:
                self.data = data
                self.attention_mask = attention_mask
                return
        # Otherwise read the loose JSON file (one document per line) and tokenize it.
        raw_data = []
        with open(path) as f:
            for line in f.readlines():
                raw_data.append(json.loads(line)['text'])
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.unk_token
        encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
        self.data = encoded_data['input_ids']
        self.attention_mask = encoded_data['attention_mask']
        # Cache the encoded tensors next to the source file for later runs.
        torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Returns (model inputs, labels); the labels are the input ids themselves.
        return {'input_ids': self.data[index],
                'attention_mask': self.attention_mask[index]}, self.data[index]
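The tokenizer call above truncates each document to at most seq_len tokens and pads shorter ones so that every row has the same length, with the attention mask marking which positions hold real tokens. The following pure-Python sketch mimics that behavior on toy token ids; the helper name and the pad id of 0 are assumptions for illustration, and it does not call the Hugging Face tokenizer.

```python
def pad_and_truncate(sequences, max_length, pad_id):
    """Truncate to max_length, then right-pad every row to the longest remaining row."""
    truncated = [seq[:max_length] for seq in sequences]
    target = max((len(seq) for seq in truncated), default=0)
    input_ids, attention_mask = [], []
    for seq in truncated:
        n_pad = target - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    return input_ids, attention_mask


ids, mask = pad_and_truncate([[5, 6, 7, 8, 9], [3, 4]], max_length=4, pad_id=0)
print(ids)
print(mask)
```

Because the mask is stored alongside the ids, the model can ignore padded positions even though the pad token here reuses an existing vocabulary entry.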

Train script(​

Import modules

ZeRO related module:

from import ZeroInitContext

Other modules:

import contextlib
import os

import colossalai
import colossalai.utils as utils
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import (InterleavedPipelineSchedule,
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import is_using_pp
from colossalai.utils.timer import MultiTimer
from model_zoo.gpt.gpt import GPTLMLoss

from dataset.webtext import WebtextDataset

Launch Colossal-AI

parser = colossalai.get_default_parser()
parser.add_argument('--from_torch', default=False, action='store_true')
args = parser.parse_args()
if args.from_torch:

logger = get_dist_logger()

Build Webtext dataloader

logger.info('Build data loader', ranks=[0])
train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
train_dataloader = utils.get_dataloader(train_ds,

Build ZeRO GPT-2 model

logger.info('Build model', ranks=[0])
use_pipeline = is_using_pp()
use_interleaved = hasattr(gpc.config.model, 'num_chunks')
use_zero3 = hasattr(gpc.config, 'zero')

# Fall back to a no-op context when ZeRO is not configured.
ctx = contextlib.nullcontext()
if use_zero3:
    ctx = ZeroInitContext(target_device=torch.cuda.current_device(),,
with ctx:
    model = gpc.config.model.pop('type')(**gpc.config.model)
if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
    model = nn.ModuleList([model])
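The snippet above combines two patterns: an optional context manager that defaults to contextlib.nullcontext(), and a config dict whose 'type' entry holds the class to instantiate. The sketch below reproduces both with plain Python stand-ins. RecordingContext and TinyModel are hypothetical placeholders (not Colossal-AI classes), and unlike the snippet above this version copies the config before popping 'type' so the caller's dict is left intact.

```python
import contextlib


class RecordingContext:
    """Hypothetical stand-in for a special init context; counts how often it is entered."""

    def __init__(self):
        self.entered = 0

    def __enter__(self):
        self.entered += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # do not swallow exceptions


class TinyModel:
    """Hypothetical model class configured through keyword arguments."""

    def __init__(self, hidden_size, num_layers):
        self.hidden_size = hidden_size
        self.num_layers = num_layers


def build_model(model_cfg, special_ctx=None):
    cfg = dict(model_cfg)          # copy so the caller's config is not mutated
    model_cls = cfg.pop('type')    # the class to build is stored under 'type'
    ctx = special_ctx if special_ctx is not None else contextlib.nullcontext()
    with ctx:
        return model_cls(**cfg)    # remaining entries become constructor kwargs


model_cfg = dict(type=TinyModel, hidden_size=768, num_layers=12)
plain = build_model(model_cfg)
recorder = RecordingContext()
wrapped = build_model(model_cfg, special_ctx=recorder)
print(plain.hidden_size, wrapped.num_layers, recorder.entered)
```

Using nullcontext keeps a single `with` block for both code paths, so the model construction line does not need to be duplicated.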

Define optimizer, loss function and learning rate scheduler

# Use the loss function from the config if one is given, otherwise fall back to GPTLMLoss.
criterion = getattr(gpc.config, 'loss_fn', None)
if criterion is not None:
    criterion = criterion.type()
else:
    criterion = GPTLMLoss()

logger.info('Build optimizer', ranks=[0])
optimizer = gpc.config.optimizer.pop('type')(
    model.parameters(), **gpc.config.optimizer)

lr_scheduler = LinearWarmupLR(
    optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
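To give a feel for what a linear warmup schedule does, here is a small pure-Python sketch of one common formulation: the learning rate rises linearly during warmup and then decays linearly to zero. This is an assumption for illustration and is not guaranteed to match the exact formula used by LinearWarmupLR; the function name and the base learning rate of 1.0 are hypothetical.

```python
def linear_warmup_lr(step, base_lr, total_steps, warmup_steps):
    """Linear ramp-up for warmup_steps, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / (warmup_steps + 1)
    return base_lr * (total_steps - step) / (total_steps - warmup_steps)


# Ten scheduler steps with five warmup steps, mirroring warmup_steps=5 above.
schedule = [linear_warmup_lr(s, base_lr=1.0, total_steps=10, warmup_steps=5)
            for s in range(11)]
for step, lr in enumerate(schedule):
    print(step, round(lr, 3))
```

Note that the tutorial passes NUM_EPOCHS as total_steps and steps the scheduler by epoch, so one "step" in this sketch corresponds to one epoch.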

Start Colossal-AI engine for training

engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,

global_batch_size = gpc.config.BATCH_SIZE * \
    gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
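The global batch size is simply the product of the per-GPU batch size, the data-parallel world size, and the number of gradient accumulation steps. A tiny worked example, with a hypothetical helper name and made-up numbers:

```python
def compute_global_batch_size(per_gpu_batch, data_parallel_size, grad_accum_steps=1):
    """Samples contributing to one optimizer update."""
    return per_gpu_batch * data_parallel_size * grad_accum_steps


single_gpu = compute_global_batch_size(8, 1)           # one GPU, no accumulation
with_accum = compute_global_batch_size(8, 1, 4)        # one GPU, 4 accumulation steps
four_gpus = compute_global_batch_size(8, 4)            # four data-parallel GPUs
print(single_gpu, with_accum, four_gpus)
```

On a single GPU the data-parallel world size is 1, so gradient accumulation is the only way to raise the global batch size without using more memory per step.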

timer = MultiTimer()

Train: Trainer API

trainer = Trainer(

hook_list = [
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
# hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
# hooks.LogMemoryByEpochHook(logger),
# hooks.LogTimingByEpochHook(timer, logger),
# hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')

Start training

DATA is the path to the Webtext JSON file.

Here we pretrain GPT-2 with ZeRO on a single GPU, so nproc_per_node=1.

#!/usr/bin/env sh
export DATA=/path/to/train_data.json
torchrun --standalone --nproc_per_node=1 --config=gpt2_configs/ --from_torch