
ZeroBubble Pipeline Parallelism

Author: Junwen Duan, Hongxin Liu

Related Paper

Introduction

ZeroBubble (V Schedule): Crucially, the backward pass B is split into two stages, the activation-gradient computation (B) and the weight-gradient computation (W). Scheduling these with a 1F1B1W-style scheme further reduces the pipeline bubble compared to the 1F1B scheme used in earlier work.
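The core idea, splitting the backward pass into separate activation-gradient and weight-gradient steps, can be sketched in plain PyTorch. This is a minimal illustration of the principle, not ColossalAI's implementation:

```python
import torch

# Minimal sketch: split the backward pass of y = x @ W into two steps.
torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
W = torch.randn(8, 8, requires_grad=True)
y = x @ W
dy = torch.ones_like(y)  # incoming gradient from the next pipeline stage

# B step: activation gradient dx = dy @ W^T. retain_graph keeps the autograd
# graph alive so the weight gradient can be computed later in the schedule.
(dx,) = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

# W step: weight gradient dw = x^T @ dy, deferred to fill pipeline bubbles.
(dw,) = torch.autograd.grad(y, W, grad_outputs=dy)
```

Because the B step is what unblocks the previous stage, it is scheduled eagerly, while the W step can be delayed into otherwise idle slots.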

Hands-On Practice

We now demonstrate how to use ZeroBubble with the booster API on 4 GPUs.

step 1. Import libraries

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler

step 2. Initialize Distributed Environment and Parallelism Group

# rank, world_size and port are supplied by your launcher (e.g. colossalai run or torchrun)
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

step 3. Initialize Module and Optimizer

Build the model and optimizer. We create a Llama model with 8 decoder layers; the PipelineGraph and pipeline schedule are then created in the next step via the get_v_schedule() function.

# Global Param
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
# Init Llama from huggingface
configuration = LlamaConfig(
    hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
    intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
    num_hidden_layers=NUM_LAYERS,
    num_attention_heads=NUM_HEADS,
    num_key_value_heads=NUM_HEADS,
    attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)

step 4. Create the PipelineGraph and Pipeline Schedule

Then, we create the PipelineGraph and pipeline schedule using the get_v_schedule() function. The PipelineGraph is initialized with the following parameters: x_cost is the runtime consumed by operation x of each model chunk, and x_mem is the amount of memory consumed by operation x of each model chunk. These parameters are estimates filled in before the pipeline starts; better schedules can be obtained by profiling the actual runtime and memory cost of the model. In the following example, we assume that the computation times of the model's forward, backward B, and backward W steps are all 1, and that the p2p communication time is also 1.

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Init schedule
pp_size = 4            # number of pipeline stages; matches pp_size in step 5
num_microbatches = 4   # matches num_microbatches in step 5
h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1,
    b_cost=1,
    w_cost=1,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()

step 5. Init Booster

Pass pp_style="zbv" when initialising the Plugin to use the ZeroBubble Pipeline.

plugin = HybridParallelPlugin(
    pp_size=4,
    num_microbatches=4,
    tp_size=1,
    sp_size=1,
    zero_stage=1,
    initial_scale=1,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,
    num_model_chunks=2,
)

dp_size = plugin.dp_size
booster = Booster(plugin=plugin)

step 6. Train Your Model

steps = 10
for step in range(steps):
    input_embeddings = torch.rand(
        NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
    ).cuda()
    # Ensure all pipeline ranks see the same input
    dist.all_reduce(input_embeddings, group=plugin.pp_group)
    data_iter = iter([{"inputs_embeds": input_embeddings}])
    output = booster.execute_pipeline(
        data_iter,
        model,
        lambda x, y: x.last_hidden_state.mean(),
        optimizer,
        return_loss=True,
        return_outputs=True,
    )
    optimizer.step()
    optimizer.zero_grad()

Advanced Practice

In ColossalAI, you can get better training performance by using MetaCache and HybridParallel with ZeroBubble.

1. Use MetaCache with ZeroBubble

Pass enable_metadata_cache=True when initialising the Plugin to use the metadata cache with the ZeroBubble pipeline.

plugin = HybridParallelPlugin(
    pp_size=2,
    num_microbatches=4,
    tp_size=2,
    sp_size=2,
    zero_stage=1,
    initial_scale=1,
    enable_metadata_cache=True,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,
    num_model_chunks=2,
)

2. HybridParallel with ZeroBubble

Pass pp_size, tp_size, and sp_size when initialising the Plugin to combine HybridParallel with the ZeroBubble pipeline.

plugin = HybridParallelPlugin(
    pp_size=2,
    num_microbatches=2,
    tp_size=2,
    sp_size=2,
    zero_stage=1,
    initial_scale=1,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,
    num_model_chunks=2,
)

Performance Benchmark

| HybridParallel Strategy | Pipeline Parallel | Sequence Parallel + Pipeline Parallel | Data Parallel + Pipeline Parallel |
| --- | --- | --- | --- |
| With 1F1B | 15.27 samples/sec | 17.22 samples/sec | 14.06 samples/sec |
| With ZeroBubble | 17.36 samples/sec | 18.38 samples/sec | 14.44 samples/sec |
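For reference, the relative speedup of ZeroBubble over 1F1B implied by the throughput numbers above can be computed directly:

```python
# Throughput (samples/sec) from the benchmark table above.
one_f1b = {"PP": 15.27, "SP+PP": 17.22, "DP+PP": 14.06}
zero_bubble = {"PP": 17.36, "SP+PP": 18.38, "DP+PP": 14.44}

# Relative speedup of ZeroBubble over 1F1B for each strategy.
speedup = {k: zero_bubble[k] / one_f1b[k] - 1 for k in one_f1b}
for k, v in speedup.items():
    print(f"{k}: {v:.1%}")
```

The gain is largest for pure pipeline parallelism (about 14%), as expected: the more of the step time that is pipeline bubble, the more ZeroBubble's B/W split can recover.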

3. Fine-tuning Scheduler Parameters


Model compatibility

| Shardformer/Model | Bert | Blip2 | Bloom | Chatglm2 | Command | Deepseek | Falcon | GPT2 | Gptj | Llama | Mistral | Opt | Qwen2 | Sam | T5 | Vit | Whisper |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ZeroBubble | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |

API Reference

class
 

colossalai.pipeline.ZeroBubbleVPipeScheduler

(stage_manager: PipelineStageManager, schedule: typing.List[colossalai.pipeline.schedule.v_schedule.ScheduledNode], num_model_chunks: int, num_microbatch: typing.Optional[int] = None, microbatch_size: typing.Optional[int] = None, enable_metadata_cache: bool = True, overlap_p2p: bool = True)
Parameters
  • stage_manager (PipelineStageManager) -- If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
  • schedule (List[ScheduledNode]) -- Schedule for ZeroBubbleVPipe.
  • num_model_chunks (int) -- The number of model chunk in a device.
  • num_microbatch (Optional[int]) -- The number of microbatch.
  • microbatch_size (Optional[int]) -- The size per microbatch.
  • enable_metadata_cache (bool) -- Whether to enable the metadata cache to accelerate communication.
  • overlap_p2p (bool) -- Whether to overlap p2p communication with computation.
Description

ZeroBubbleVPipeScheduler

function
 

backward_b_step

(model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], model_chunk_id: int, optimizer: OptimizerWrapper, input_obj: typing.Optional[dict], output_obj: typing.Union[dict, torch.Tensor], output_obj_grad: typing.Optional[dict])
Parameters
  • model_chunk (ModuleList or Module) -- Model Chunk to be run;
  • model_chunk_id (int) -- The current model chunk idx;
  • optimizer (OptimizerWrapper) -- Optimizer to update the model
  • input_obj (Optional[dict]) -- x, the input object (the microbatch or the output of the previous stage).
  • output_obj (Union[dict, torch.Tensor]) -- y.
  • output_obj_grad (dict) -- dy.
Returns

Optional[dict]: dx.

Description
Backward B step of the pipeline; computes the activation gradient dx = w * dy.
function
 

backward_w_step

(model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], model_chunk_id: int, optimizer: OptimizerWrapper, output_obj: typing.Union[dict, torch.Tensor], output_obj_grad: typing.Optional[dict])
Parameters
  • model_chunk (ModuleList or Module) -- Model Chunk to be run;
  • model_chunk_id (int) -- The current model chunk idx;
  • optimizer (OptimizerWrapper) -- Optimizer to update the model
  • output_obj (Union[dict, torch.Tensor]) -- y.
  • output_obj_grad (dict) -- dy.
Returns

: Nothing to return; we only calculate dw and then update w.

Description
Backward W step of the pipeline; computes the weight gradient dw = x * dy.
function
 

forward_backward_step

(model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], data_iter: typing.Iterable, criterion: typing.Callable[..., typing.Any], optimizer: typing.Optional[colossalai.interface.optimizer.OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False)
Parameters
  • model_chunk (ModuleList or Module) -- Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
  • data_iter (Iterable) -- Data iterator.
  • criterion (Callable[[Any, Any], Tensor]) -- Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
  • optimizer (OptimizerWrapper, optional) -- Optimizer to be used. Can be None when only forward is executed. Defaults to None.
  • return_loss (bool, optional) -- Whether to return loss. Defaults to False.
  • return_outputs (bool, optional) -- Whether to return model outputs. Defaults to False.
Returns

dict: A dict with keys: 'loss' and 'outputs'.

Description
function
 

forward_step

(model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], model_chunk_id: int, micro_batch: typing.Optional[dict], input_obj: typing.Optional[dict], criterion: typing.Callable, accum_loss: typing.Optional[torch.Tensor] = None, outputs: typing.Optional[typing.List[typing.Any]] = None)
Parameters
  • model_chunk (ModuleList or Module) -- Model Chunk to be run;
  • model_chunk_id (int) -- The current model chunk idx;
  • input_obj (Optional[dict]) -- x;
  • criterion (Callable) -- loss function;
  • accum_loss (Optional[torch.Tensor], optional) -- Accumulated loss. Defaults to None.
  • outputs (Optional[List[Any]], optional) -- List to store the output of the last stage (final output). Defaults to None.
Returns

Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).

Description
Forward one step of the pipeline
function
 

get_model_chunk_id

(microbatch_id: int, is_forward: bool)
Parameters
  • microbatch_id (int) -- the current microbatch idx
  • is_forward (bool) -- whether this is the forward process
Returns

int: The model chunk idx of the input microbatch_id

Description
Helper method to get the model chunk ID given the iteration number.
function
 

load_batch

(data_iter: typing.Iterable, device: typing.Optional[torch.device] = None)
Parameters
  • data_iter (Iterable) -- Data iterator.
  • device (Optional[torch.device], optional) -- Target device. Defaults to None.
Description
Load a batch from data iterator.
function
 

load_micro_batch

(model_chunk_id: int)
Parameters
  • model_chunk_id (int) -- the current model chunk idx.
Returns

Any: Micro batch.

Description
Load a micro batch from the current batch.
function
 

recv_backward

(model_chunk_id: int, next_rank: int = None)
Parameters
  • model_chunk_id (int) -- The current model chunk idx.
  • next_rank (int, optional) -- The rank of the source of the tensor.
Returns

Any: The input gradient tensor or gradient tensor list. Any: The wait handles for the communication.

Description
Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. For ZBV.
function
 

recv_forward

(model_chunk_id: int, prev_rank: int = None)
Parameters
  • model_chunk_id (int) -- The current model chunk idx.
  • prev_rank (int, optional) -- The rank of the source of the tensor.
Returns

Any: The input tensor or input tensor list. Any: The wait handles for the communication.

Description
Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For ZBV.
function
 

run_forward_backward

(model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], data_iter: typing.Iterable, criterion: typing.Callable[..., typing.Any], optimizer: typing.Optional[colossalai.interface.optimizer.OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False)
Description

Runs the ZeroBubble schedule, with communication between pipeline stages.

function
 

schedule_b

(scheduled_node, model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], model_chunk_id: int, optimizer: OptimizerWrapper)
Parameters

scheduled_node --

  • model_chunk (ModuleList or Module) -- Model Chunk to be run;
  • model_chunk_id (int) -- The current model chunk idx;
Returns

: Nothing.

Description
A complete backward B schedule: recv bwd --> compute backward B step --> send bwd.
function
 

schedule_f

(scheduled_node, model_chunk: ModuleList, model_chunk_id: int, criterion: typing.Callable, accum_loss: typing.Optional[torch.Tensor] = None, outputs: typing.Optional[typing.List[typing.Any]] = None)
Parameters

scheduled_node --

  • model_chunk (ModuleList or Module) -- Model Chunk to be run;
  • model_chunk_id (int) -- The current model chunk idx;
  • criterion (Callable) -- loss function;
  • accum_loss (Optional[torch.Tensor], optional) -- Accumulated loss. Defaults to None.
  • outputs (Optional[List[Any]], optional) -- List to store the output of the last stage (final output). Defaults to None.
Returns

: Nothing.

Description
A complete forward schedule: recv fwd --> compute fwd --> send fwd.
function
 

schedule_w

(scheduled_node, model_chunk: typing.Union[torch.nn.modules.container.ModuleList, torch.nn.modules.module.Module], model_chunk_id: int, optimizer: OptimizerWrapper)
Parameters

scheduled_node --

  • model_chunk (ModuleList or Module) -- Model Chunk to be run;
  • model_chunk_id (int) -- The current model chunk idx;
Returns

: Nothing.

Description
A complete backward W schedule: fetch y and dy from the buffer --> compute the backward W step (calculate dw and update w).
function
 

send_backward

(model_chunk_id: int, prev_rank: int = None)
Parameters
  • model_chunk_id (int) -- The current model chunk idx.
  • prev_rank (int, optional) -- The rank of the recipient of the tensor.
Returns

Any: The wait handles for the communication.

Description
Sends the gradient tensor to the previous stage in pipeline. For ZBV.
function
 

send_forward

(model_chunk_id: int, next_rank: int = None)
Parameters
  • model_chunk_id (int) -- The current model chunk idx.
  • next_rank (int, optional) -- The rank of the recipient of the tensor.
Returns

Any: The wait handles for the communication.

Description
Sends the input tensor to the next stage in pipeline. For ZBV.