ZeroBubble Pipeline Parallelism
Authors: Junwen Duan, Hongxin Liu
Related Paper
Introduction
ZeroBubble (V Schedule): Crucially, ZeroBubble splits the backward pass B into two stages, the activation-gradient computation (B) and the weight-gradient computation (W). A schedule such as 1F1B1W can then further reduce the pipeline bubble compared to the 1F1B schedule used in earlier work.
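The B/W split can be illustrated with a minimal, purely illustrative scalar "layer" (a sketch, not ColossalAI code): B propagates the activation gradient to the previous stage right away, while W computes the weight gradient and can be deferred to fill pipeline bubbles.

```python
# Illustrative sketch (not ColossalAI code): for y = w * x, the backward pass
# splits into B (activation gradient, needed by the previous stage immediately)
# and W (weight gradient, which can be scheduled later, as in 1F1B1W).

def forward(x, w):
    return x * w

def backward_b(w, dy):
    # B step: dx = dy * dy_dx = dy * w, sent upstream immediately
    return dy * w

def backward_w(x, dy):
    # W step: dw = dy * dy_dw = dy * x, can be deferred to fill bubbles
    return dy * x

x, w, dy = 2.0, 3.0, 1.0
dx = backward_b(w, dy)  # 3.0
dw = backward_w(x, dy)  # 2.0
```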
Hands-On Practice
We now demonstrate how to use the ZeroBubble schedule through the booster API with 4 GPUs.
step 1. Import libraries
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
step 2. Initialize the Distributed Environment and Parallelism Groups
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
step 3. Initialize Module, Optimizer, and Pipeline Schedule
Build the model and optimizer. We create a Llama model with 8 decoder layers, then initialize the PipelineGraph and pipeline schedule with the get_v_schedule() function.
# Global Param
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
# Init Llama from huggingface
configuration = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
step 4. Initialize the PipelineGraph and Pipeline Schedule
Next, create the PipelineGraph and pipeline schedule using the get_v_schedule() function. The PipelineGraph is initialized with the following parameters: x_cost is the runtime consumed by operation x of each model chunk, and x_mem is the amount of memory consumed by operation x of each model chunk. These values are estimated and filled in before the pipeline starts; better results can be obtained by measuring the actual runtime and memory cost of the model. In the example below, we assume that the forward (F), activation-gradient backward (B), and weight-gradient backward (W) computations each cost 1, and that the p2p communication also costs 1.
# Init schedule
pp_size, num_microbatches = 4, 4
h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
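As a sanity check on the memory coefficients above (a standalone sketch with example values, not ColossalAI code): mem_b is defined as -(mem_w + mem_f), so one complete F + B + W cycle frees exactly the memory it allocated.

```python
# Standalone check of the memory bookkeeping used for the schedule: mem_b is
# defined as -(mem_w + mem_f), so a full F + B + W cycle nets to zero memory.
h, a, s = 16, 4, 1024        # example hidden size, attention heads, sequence length
mem_f = 34 * h + 5 * a * s   # memory held after a forward pass
mem_w = -32 * h              # memory released once the weight gradient is done
mem_b = -mem_w - mem_f       # memory released by the activation-gradient step
assert mem_f + mem_b + mem_w == 0
```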
step 5. Initialize the Booster
Pass pp_style="zbv" when initializing the plugin to use the ZeroBubble pipeline.
plugin = HybridParallelPlugin(
pp_size=4,
num_microbatches=4,
tp_size=1,
sp_size=1,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
# Boost the model and optimizer before running the pipeline
model, optimizer, _, _, _ = booster.boost(model, optimizer)
step 6. Train Your Model
steps = 10
for step in range(steps):
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
)
data_iter = iter([{"inputs_embeds": input_embeddings}])
output = booster.execute_pipeline(
data_iter,
model,
lambda x, y: x.last_hidden_state.mean(),
optimizer,
return_loss=True,
return_outputs=True,
)
optimizer.step()
optimizer.zero_grad()
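For intuition on the configuration above (a standalone arithmetic sketch; the actual layer assignment is decided internally by ColossalAI): with pp_size=4 and num_model_chunks=2, each pipeline rank holds two model chunks, so the 8 decoder layers divide into one layer per chunk.

```python
# Standalone sketch: how many layers each model chunk holds under the V
# schedule in the example above (ColossalAI decides the actual assignment).
NUM_LAYERS, pp_size, num_model_chunks = 8, 4, 2
layers_per_chunk = NUM_LAYERS // (pp_size * num_model_chunks)
print(layers_per_chunk)  # 1
```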
Advanced Practice
In ColossalAI, you can achieve better training performance by combining ZeroBubble with the metadata cache and hybrid parallelism.
1. Use MetaCache with ZeroBubble
Pass enable_metadata_cache=True when initializing the plugin to use the metadata cache with the ZeroBubble pipeline.
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=4,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
enable_metadata_cache=True,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
2. HybridParallel with ZeroBubble
Pass pp_size, tp_size, and sp_size when initializing the plugin to combine hybrid parallelism with the ZeroBubble pipeline.
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=2,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
Performance Benchmark
| HybridParallel Strategy | Pipeline Parallel | Sequence Parallel + Pipeline Parallel | Data Parallel + Pipeline Parallel |
|---|---|---|---|
| With 1F1B | 15.27 samples/sec | 17.22 samples/sec | 14.06 samples/sec |
| With Zero Bubble | 17.36 samples/sec | 18.38 samples/sec | 14.44 samples/sec |
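The relative gains implied by the benchmark above can be computed directly (a quick arithmetic sketch, not part of the benchmark itself):

```python
# Relative throughput gain of ZeroBubble over 1F1B, from the benchmark table.
pairs = {
    "Pipeline Parallel": (15.27, 17.36),
    "Sequence Parallel + Pipeline Parallel": (17.22, 18.38),
    "Data Parallel + Pipeline Parallel": (14.06, 14.44),
}
gains = {name: (zb - base) / base * 100 for name, (base, zb) in pairs.items()}
for name, gain in gains.items():
    print(f"{name}: +{gain:.1f}%")
```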
3. Fine-tuning Scheduler Parameters
Model compatibility
| Shardformer/Model | Bert | Blip2 | Bloom | Chatglm2 | Command | Deepseek | Falcon | GPT2 | Gptj | Llama | Mistral | Opt | Qwen2 | Sam | T5 | Vit | Whisper |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ZeroBubble | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | |
API Reference
class
colossalai.pipeline.ZeroBubbleVPipeScheduler
- stage_manager (PipelineStageManager) -- If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
- schedule (List[ScheduledNode]) -- Schedule for ZeroBubbleVPipe.
- num_model_chunks (int) -- The number of model chunks on a device.
- num_microbatch (Optional[int]) -- The number of microbatches.
- microbatch_size (Optional[int]) -- The size of each microbatch.
- enable_metadata_cache (bool) -- Whether to enable the metadata cache to accelerate communication.
- overlap_p2p (bool) -- Whether to overlap p2p communication with computation.
ZeroBubbleVPipeScheduler
function
backward_b_step
- model_chunk (ModuleList or Module) -- Model Chunk to be run;
- model_chunk_id (int) -- The current model chunk idx;
- optimizer (OptimizerWrapper) -- Optimizer to update the model
- input_obj (Optional[Tuple[dict]]) -- The stage input x, as (microbatch, input_obj).
- output_obj (Union[dict, torch.Tensor]) -- The stage output y.
- output_obj_grad (dict) -- The gradient of the output, dy.
Optional[dict]: The gradient of the input, dx.
function
backward_w_step
- model_chunk (ModuleList or Module) -- Model Chunk to be run;
- model_chunk_id (int) -- The current model chunk idx;
- optimizer (OptimizerWrapper) -- Optimizer to update the model
- output_obj (Union[dict, torch.Tensor]) -- The stage output y.
- output_obj_grad (dict) -- The gradient of the output, dy.
: Nothing is returned; this step only computes dw and then updates w.
function
forward_backward_step
- model_chunk (ModuleList or Module) -- Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
- data_iter (Iterable) -- Data iterator.
- criterion (Callable[[Any, Any], Tensor]) -- Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
- optimizer (OptimizerWrapper, optional) -- Optimizer to be used. Can be None when only forward is executed. Defaults to None.
- return_loss (bool, optional) -- Whether to return loss. Defaults to False.
- return_outputs (bool, optional) -- Whether to return model outputs. Defaults to False.
dict: A dict with keys: 'loss' and 'outputs'.
function
forward_step
- model_chunk (ModuleList or Module) -- Model Chunk to be run;
- model_chunk_id (int) -- The current model chunk idx;
- input_obj (Optional[dict]) -- x;
- criterion (Callable) -- loss function;
- accum_loss (Optional[torch.Tensor], optional) -- Accumulated loss. Defaults to None.
- outputs (Optional[List[Any]], optional) -- List to store the output of the last stage (final output). Defaults to None.
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
function
get_model_chunk_id
- microbatch_id (int) -- the current microbatch idx
- forward (bool) -- Whether it is the forward pass.
int: The model chunk idx of the input microbatch_id
function
load_batch
- data_iter (Iterable) -- Data iterator.
- device (Optional[torch.device], optional) -- Target device. Defaults to None.
function
load_micro_batch
- microbatch_id (int) -- the current microbatch idx.
Any: Micro batch.
function
recv_backward
- model_chunk_id (int) -- The current model chunk idx.
- next_rank (int, optional) -- The rank of the source of the tensor.
Any: The input gradient tensor or gradient tensor list, plus the wait handles for the communication.
function
recv_forward
- model_chunk_id (int) -- The current model chunk idx.
- prev_rank (int, optional) -- The rank of the source of the tensor.
Any: The input tensor or input tensor list, plus the wait handles for the communication.
function
run_forward_backward
Runs the ZeroBubble schedule, with communication between pipeline stages.
function
schedule_b
- scheduled_node (ScheduledNode) -- The scheduled node to execute.
- model_chunk (ModuleList or Module) -- Model Chunk to be run;
- model_chunk_id (int) -- The current model chunk idx;
: Nothing.
function
schedule_f
- scheduled_node (ScheduledNode) -- The scheduled node to execute.
- model_chunk (ModuleList or Module) -- Model Chunk to be run;
- model_chunk_id (int) -- The current model chunk idx;
- criterion (Callable) -- loss function;
- accum_loss (Optional[torch.Tensor], optional) -- Accumulated loss. Defaults to None.
- outputs (Optional[List[Any]], optional) -- List to store the output of the last stage (final output). Defaults to None.
: Nothing.
function
schedule_w
- scheduled_node (ScheduledNode) -- The scheduled node to execute.
- model_chunk (ModuleList or Module) -- Model Chunk to be run;
- model_chunk_id (int) -- The current model chunk idx;
: Nothing.
function
send_backward
- model_chunk_id (int) -- The current model chunk idx.
- prev_rank (int, optional) -- The rank of the recipient of the tensor.
Any: The wait handles for the communication.
function
send_forward
- model_chunk_id (int) -- The current model chunk idx.
- next_rank (int, optional) -- The rank of the recipient of the tensor.
Any: The wait handles for the communication.