
Shardformer

Author: Baizhou Zhang, Bin Jia

Prerequisite

Example Code

Related Paper

Introduction

When training large transformer models such as LLaMA-2 70B or OPT 175B, model parallelism methods that divide a huge model into smaller shards, such as tensor parallelism or pipeline parallelism, are essential to fit within the limits of GPU memory. However, manually partitioning a model and rewriting its forward/backward logic can be difficult for users who are not familiar with distributed training. Meanwhile, the HuggingFace transformers library has gradually become users' first choice of model source, and most mainstream large models have been open-sourced in the HuggingFace transformers model library.

Motivated by this, the ColossalAI team developed Shardformer, a feature that automatically prepares popular HuggingFace transformer models for model parallelism (tensor parallelism/pipeline parallelism). This module aims to make parallelization hassle-free for users without a systems background. With a few lines of code, users can turn a model into a state ready for distributed training. Shardformer also contains various optimization tools for acceleration and memory saving during the forward/backward pass.

Supporting Information

Model/Feature Compatibility Matrix:

Model/Feature | Tensor Parallel | Pipeline Parallel | Lazy Initialization | xFormers | Flash Attention 2 | JIT Fused Operators | Fused LayerNorm | Sequence Parallel | Sequence Overlap
Llama V1/V2✔️✔️✔️✔️✔️✔️✔️
OPT✔️✔️✔️✔️✔️✔️✔️
BLOOM✔️✔️✔️✔️✔️✔️✔️✔️✔️
ChatGLM 2✔️✔️✔️✔️✔️✔️✔️✔️✔️
BERT✔️✔️✔️✔️✔️✔️✔️✔️✔️
GPT 2✔️✔️✔️✔️✔️✔️✔️✔️✔️
T5✔️✔️✔️✔️✔️✔️✔️
ViT✔️✔️✔️✔️✔️✔️
Whisper✔️✔️✔️✔️✔️✔️
SAM✔️✔️✔️✔️✔️
Blip2✔️✔️✔️✔️✔️
Falcon✔️✔️✔️✔️✔️✔️

List of model families we plan to support in the near future:

  • RoBERTa
  • ALBERT
  • ERNIE
  • GPT Neo
  • GPT-J
  • BEiT
  • SwinTransformer V1/V2
  • qwen

The support matrix will grow as more models and optimization tools emerge. If you have suggestions on models or optimizations we should support, please feel free to raise them in the Issues section of our project.

Usage

Shardformer Configuration

The configuration of Shardformer is controlled by class ShardConfig:

class colossalai.shardformer.ShardConfig(tensor_parallel_process_group: typing.Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, pipeline_stage_manager: typing.Optional[colossalai.pipeline.stage_manager.PipelineStageManager] = None, enable_tensor_parallelism: bool = True, enable_fused_normalization: bool = False, enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_all_optimization: bool = False, enable_sequence_parallelism: bool = False, enable_sequence_overlap: bool = False, parallel_output: bool = True, extra_kwargs: typing.Dict[str, typing.Any] = <factory>)
Parameters
  • tensor_parallel_process_group (Optional[ProcessGroup]) -- The process group for tensor parallelism; required when using tensor parallelism. Defaults to None, which is the global process group.
  • pipeline_stage_manager (Optional[PipelineStageManager]) -- If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
  • enable_tensor_parallelism (bool) -- Whether to use tensor parallelism. Defaults to True.
  • enable_fused_normalization (bool) -- Whether to use fused layernorm. Defaults to False.
  • enable_flash_attention (bool, optional) -- Whether to switch on flash attention. Defaults to False.
  • enable_jit_fused (bool, optional) -- Whether to switch on JIT fused operators. Defaults to False.
  • enable_sequence_parallelism (bool) -- Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
  • enable_sequence_overlap (bool) -- Whether to turn on sequence overlap, which overlaps computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
  • enable_all_optimization (bool) -- Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
Description

The config for sharding the HuggingFace model.

If you want to enable Apex fused LayerNorm, please install apex. If you want to enable flash attention, please install flash_attn. In addition, xFormers' cutlass_op can serve as a fallback for flash attention.
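The interplay between the boolean switches above can be sketched as a small dataclass. ShardConfigSketch is a hypothetical stand-in, not a ColossalAI class: it models only the rules stated in this document (enable_all_optimization turns on the five listed optimizations, sequence overlap requires sequence parallelism, and sequence parallelism is only used together with tensor parallelism). Process groups, the stage manager, parallel_output and extra_kwargs are left out, and the real class may perform its checks in a different order.

```python
from dataclasses import dataclass


@dataclass
class ShardConfigSketch:
    """Hypothetical stand-in for ShardConfig: only the boolean switches are modelled."""

    enable_tensor_parallelism: bool = True
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False
    enable_all_optimization: bool = False
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False

    def __post_init__(self) -> None:
        # enable_all_optimization switches on every optimization listed in the parameters.
        if self.enable_all_optimization:
            self.enable_fused_normalization = True
            self.enable_flash_attention = True
            self.enable_jit_fused = True
            self.enable_sequence_parallelism = True
            self.enable_sequence_overlap = True
        # Sequence parallelism is only used together with 1D tensor parallelism.
        if self.enable_sequence_parallelism and not self.enable_tensor_parallelism:
            raise ValueError("sequence parallelism requires tensor parallelism")
        # Sequence overlap only makes sense on top of sequence parallelism.
        if self.enable_sequence_overlap and not self.enable_sequence_parallelism:
            raise ValueError("sequence overlap requires sequence parallelism")
```

For example, ShardConfigSketch(enable_all_optimization=True) ends up with every optimization flag set, while ShardConfigSketch(enable_sequence_overlap=True) raises a ValueError because sequence parallelism is off.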

Enabling Shardformer

Enabling Shardformer through a Booster initialized with HybridParallelPlugin is the recommended way to use Shardformer. The main reason is that pipeline parallelism cannot work without calling the execute_pipeline method of Booster. Besides, HybridParallelPlugin makes it possible to combine Shardformer's features with other useful features, such as mixed precision training or ZeRO.

Here is an example of how to trigger Shardformer through HybridParallelPlugin. Move to the root directory of this example and execute:

torchrun --standalone --nproc_per_node 4  finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"

This starts fine-tuning a BERT model wrapped by Shardformer. The wrapping itself is performed by HybridParallelPlugin.

Let's delve into the code of finetune.py:

In the main function, the plugin is created with the following code:

...
elif args.plugin == "hybrid_parallel":
    # modify the param accordingly for finetuning test cases
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        num_microbatches=None,
        microbatch_size=1,
        enable_all_optimization=True,
        zero_stage=1,
        precision="fp16",
        initial_scale=1,
    )

Here you can change the configuration of plugin by setting tp_size, pp_size or zero_stage to other values. More details about plugin configuration can be found in Booster Plugins Doc.
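To see how tp_size and pp_size interact with the number of launched processes, the sketch below assumes that the remaining factor of the world size becomes the data-parallel size, and lays ranks out on a (dp, pp, tp) grid with tp varying fastest. The function name and the rank ordering are assumptions for illustration, not HybridParallelPlugin's API. With the torchrun command above (4 processes), tp_size=1 and pp_size=2 leave a data-parallel size of 2, which is the group over which zero_stage=1 would shard optimizer states.

```python
def hybrid_parallel_layout(world_size: int, tp_size: int, pp_size: int) -> dict:
    """Split world_size processes into (dp, pp, tp) coordinates (illustrative only)."""
    if world_size % (tp_size * pp_size) != 0:
        raise ValueError("world_size must be divisible by tp_size * pp_size")
    # Whatever is left after tensor and pipeline parallelism becomes data parallelism.
    dp_size = world_size // (tp_size * pp_size)
    coords = {}
    for rank in range(world_size):
        tp = rank % tp_size  # tp varies fastest (assumed ordering)
        pp = (rank // tp_size) % pp_size
        dp = rank // (tp_size * pp_size)
        coords[rank] = (dp, pp, tp)
    return {"dp_size": dp_size, "coords": coords}
```

Calling hybrid_parallel_layout(4, tp_size=1, pp_size=2) places ranks 0 and 1 on the two pipeline stages of the first data-parallel replica, and ranks 2 and 3 on the second replica.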

If pipeline parallelism is not enabled, train in the same way as with other booster plugins (first boost with Booster, then run the forward and backward passes in the usual way). However, if pipeline parallelism is enabled, usage differs from the normal case in several ways:

  1. Before any forward or backward pass, the criterion function (loss function) is wrapped so that its arguments match what the pipeline expects:

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss
  2. In the train_epoch function, the dataloader is converted into an iterator before running the pipeline:

    train_dataloader_iter = iter(train_dataloader)
  3. Run the forward and backward passes by calling the Booster.execute_pipeline method:

    outputs = booster.execute_pipeline(
        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
    )

    This method already performs the backward pass, so there is no need to call loss.backward() afterwards. More details about Booster.execute_pipeline can be found in the Booster API Doc.
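Why the criterion is wrapped and the dataloader becomes an iterator can be illustrated with a single-process toy. In this hypothetical toy_execute_pipeline (not Booster's implementation), the schedule itself pulls the next batch from the iterator, cuts it into microbatches, pushes each microbatch through every stage in turn, and calls the criterion with both the final outputs and the original microbatch. Plain Python lists stand in for tensors, stages are ordinary functions, and there is no inter-device communication or backward pass; averaging the microbatch losses is also a simplifying assumption.

```python
from typing import Callable, Iterator, List, Sequence

Batch = List[float]


def toy_execute_pipeline(
    data_iter: Iterator[Batch],
    stages: Sequence[Callable[[Batch], Batch]],
    criterion: Callable[[Batch, Batch], float],
    microbatch_size: int,
) -> float:
    """Single-process toy: run one batch through all stages, microbatch by microbatch."""
    batch = next(data_iter)  # the schedule pulls the next batch itself
    losses = []
    for start in range(0, len(batch), microbatch_size):
        inputs = batch[start:start + microbatch_size]
        hidden = inputs
        for stage in stages:  # in a real pipeline each stage lives on a different device
            hidden = stage(hidden)
        # The criterion sees the outputs and the original microbatch (e.g. to read labels).
        losses.append(criterion(hidden, inputs))
    return sum(losses) / len(losses)
```

Because the schedule decides when to fetch data and when to compute the loss, the caller hands over an iterator and a loss function with the (outputs, inputs) signature instead of calling the model directly.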

You can also use Shardformer by calling the Shardformer APIs manually. However, this usage is not recommended, since pipeline parallelism cannot run without Booster.

Here is an example of triggering Shardformer by calling the Shardformer APIs. In the train function of the example code, the model is wrapped by Shardformer with the following code:

...
if dist.get_world_size() > 1:
    tp_group = dist.new_group(backend="nccl")

    # First create configuration for Shardformer
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,
        enable_all_optimization=True
    )

    # Then create ShardFormer object with created config
    shard_former = ShardFormer(shard_config=shard_config)

    # Finally shard the model using ShardFormer.optimize method
    model, _ = shard_former.optimize(model)
...

Precautions

  1. When pipeline parallelism is enabled, do not run the forward/backward pass in the conventional way (model(input), loss.backward()), as this will cause unexpected errors. Instead, run the forward/backward pass by calling the booster.execute_pipeline method.

  2. When you use Shardformer to process classification models such as GPT2ForSequenceClassification or ViTForImageClassification, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.

  3. Training ChatGLM-2 6B is a special case: since HuggingFace transformers does not officially support ChatGLM at present, please import the configuration/model classes with

    from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
    from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

    when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
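For the classifier-layer precaution above, the number of labels can be padded up to the next multiple of the tensor parallel size before the model is built. padded_num_labels and labels_per_rank are hypothetical helpers, not Shardformer APIs.

```python
def padded_num_labels(num_labels: int, tp_size: int) -> int:
    """Round num_labels up to the next multiple of tp_size (hypothetical helper)."""
    if tp_size <= 0:
        raise ValueError("tp_size must be positive")
    # Integer ceiling division avoids floating-point rounding issues.
    return -(-num_labels // tp_size) * tp_size


def labels_per_rank(num_labels: int, tp_size: int) -> int:
    """Classifier outputs each tensor parallel rank would hold after padding."""
    return padded_num_labels(num_labels, tp_size) // tp_size
```

With a transformers config, one option (an assumption about your setup) is to set config.num_labels = padded_num_labels(config.num_labels, tp_size) before instantiating the model; the extra dummy labels never appear in the training data.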

How Shardformer Works

Main Idea

Generally, Shardformer works through the following four kinds of replacements:

  1. Replacing original PyTorch modules (e.g. nn.Linear, nn.Embedding) with crafted distributed modules. Each distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters. Its new forward method also replaces the original one in order to execute distributed computation, such as the split/gather operations of linear layers under tensor parallelism. Each distributed module implements a from_native_module static method to convert the PyTorch module into its corresponding distributed module.

  2. Replacing attributes of original HuggingFace Transformers layers with appropriate attributes for distributed training. For example, when training LLaMA-2 with a tensor parallel size of 2, the attribute num_heads of LlamaDecoderLayer (the number of attention heads in each layer) should be replaced with model.config.num_attention_heads // 2.

  3. Replacing the forward methods implemented by the original HuggingFace Transformers library with our customized forward methods. This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between pipeline stages. Optimization methods such as flash attention or sequence parallelism can also be injected into the forward process through our customized forward methods.

  4. Replacing the full copy of model parameters and optimizer states with the partial copy owned by the current device (this is why it's called Shardformer). By executing the ModelSharder.shard method, the current device keeps only the part of the model parameters it is responsible for. Specifically, these are the assigned parameter shards when using tensor parallelism, the parameters belonging to the current pipeline stage when using pipeline parallelism, or both. All other parameters are released to free memory. As a result, the optimizer only computes states for this subset of parameters, which saves further memory.
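The split/gather computation mentioned in item 1, and the idea that each device keeps only its own parameter shard (item 4), can be checked numerically with a pure-Python toy. It follows the common Megatron-style column-parallel/row-parallel linear layout, which is an assumption about what the distributed linear modules compute; nested lists stand in for tensors, the k "ranks" are loop iterations, and list concatenation and element-wise summation stand in for All-Gather and All-Reduce.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a (m, n) and b (n, p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def column_parallel_linear(x: Matrix, w: Matrix, k: int) -> Matrix:
    """Each rank holds w[:, r*step:(r+1)*step] and computes a slice of the output."""
    step = len(w[0]) // k
    shards = [[row[r * step:(r + 1) * step] for row in w] for r in range(k)]
    partial = [matmul(x, shard) for shard in shards]
    # Gather: concatenate the output slices along the last dimension.
    return [sum((p[i] for p in partial), []) for i in range(len(x))]


def row_parallel_linear(x: Matrix, w: Matrix, k: int) -> Matrix:
    """Each rank holds w[r*step:(r+1)*step, :] and the matching input features."""
    step = len(w) // k
    partial = []
    for r in range(k):
        x_shard = [row[r * step:(r + 1) * step] for row in x]
        w_shard = w[r * step:(r + 1) * step]
        partial.append(matmul(x_shard, w_shard))
    # All-Reduce: element-wise sum of the partial outputs from all ranks.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partial)]
```

Both sharded versions reproduce matmul(x, w) exactly, while each rank only ever touches 1/k of the weight.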

All of these replacements are implemented with manually written policies and forward functions. If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our Shardformer development document and pipeline parallelism design for more details.
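A policy can be pictured as a table of rewrites keyed by attribute name. The toy below models only the attribute replacement from item 2; ToyAttention, ToyPolicy and split_heads are hypothetical names rather than Shardformer's policy classes, and real policies also describe module and forward-method replacements.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class ToyAttention:
    """Hypothetical stand-in for the attention part of a decoder layer."""
    num_heads: int


@dataclass
class ToyPolicy:
    """Maps attribute names to rules of the form rule(original_value, tp_size)."""
    attribute_replacement: Dict[str, Callable[[int, int], int]] = field(default_factory=dict)

    def apply(self, module: object, tp_size: int) -> None:
        for name, rule in self.attribute_replacement.items():
            setattr(module, name, rule(getattr(module, name), tp_size))


def split_heads(num_heads: int, tp_size: int) -> int:
    # Heads must divide evenly, otherwise ranks would hold unequal shards.
    if num_heads % tp_size:
        raise ValueError("num_heads must be divisible by tp_size")
    return num_heads // tp_size


policy = ToyPolicy(attribute_replacement={"num_heads": split_heads})
```

Applying policy to ToyAttention(num_heads=64) with tp_size=2 leaves each rank with 32 heads, mirroring the num_attention_heads // 2 rule above.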

Sequence Parallelism

Sequence parallelism is a special optimization method supported by Shardformer. Sequence parallelism in Shardformer differs from the variant of sequence parallelism that focuses on ring attention. In Shardformer, sequence parallelism is only used together with 1D tensor parallelism to further reduce the memory occupied by activation tensors during computation.
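The memory saving comes from partitioning the non-tensor-parallel regions along the sequence dimension, so each of k devices stores activations of shape (batch, seq_len/k, hidden) instead of the full tensor. The toy below uses nested lists for tensors and a hypothetical split/gather pair; it shows only the shapes and the round trip, not Shardformer's communication code.

```python
from typing import List, Tuple

Tensor3D = List[List[List[float]]]  # nested lists shaped (batch, seq_len, hidden)


def shape(x: Tensor3D) -> Tuple[int, int, int]:
    return (len(x), len(x[0]), len(x[0][0]))


def split_sequence(x: Tensor3D, k: int, rank: int) -> Tensor3D:
    """Keep only this rank's slice of the sequence: (batch, seq_len/k, hidden)."""
    seq_len = len(x[0])
    if seq_len % k:
        raise ValueError("seq_len must be divisible by k")
    step = seq_len // k
    return [sample[rank * step:(rank + 1) * step] for sample in x]


def all_gather_sequence(shards: List[Tensor3D]) -> Tensor3D:
    """Concatenate per-rank slices back along the sequence dimension."""
    batch = len(shards[0])
    return [sum((shard[b] for shard in shards), []) for b in range(batch)]
```

With batch 2, sequence length 8, hidden size 3 and k = 4, each rank holds a (2, 2, 3) slice, and gathering the four slices in rank order restores the original tensor.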

  1. In normal 1D tensor parallelism, there are 2 communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to collect the gradients from all devices, and $\vec{g}$ performs one All-Reduce in the forward pass to collect the complete outputs from all devices.

  2. When using sequence parallelism, $g$ needs to do an All-Gather to gather the inputs along the sequence dimension during the forward pass, and a Reduce-Scatter to split the gradient during the backward pass. $\vec{g}$ needs to do a Reduce-Scatter to split the output of the Row Linear layer of tensor parallelism across all devices along the sequence dimension, and an All-Gather to get the whole gradient during the backward pass.

  3. NCCL's implementation of All-Reduce adopts the Ring All-Reduce approach, which consists of a Reduce-Scatter operation and an All-Gather operation of equal cost. Therefore, compared with plain tensor parallelism, sequence parallelism does not introduce additional communication overhead.

  4. One important thing to note is that when sequence parallelism is used with the Column Linear module of tensor parallelism, the complete input is needed to compute gradients in the backward pass. During the forward pass, only the portion of the input split along the sequence dimension is retained, with shape $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, our implementation can overlap the gradient computation with this All-Gather communication, so it does not introduce additional communication overhead (this corresponds to the enable_sequence_overlap parameter in Shardformer).
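The equal-cost argument in items 1 to 3 can be checked with the standard ring cost model, in which each of k devices sends (k-1)/k of an n-element tensor for a Reduce-Scatter or an All-Gather, and twice that for a ring All-Reduce. The sketch counts only the $g$ and $\vec{g}$ collectives of one tensor-parallel block, ignores latency terms, and uses hypothetical function names.

```python
from fractions import Fraction


def per_device_volume(op: str, n: int, k: int) -> Fraction:
    """Elements each device sends for a ring collective over an n-element tensor."""
    step = Fraction(k - 1, k) * n
    if op in ("reduce_scatter", "all_gather"):
        return step
    if op == "all_reduce":  # ring all-reduce = reduce-scatter followed by all-gather
        return 2 * step
    raise ValueError(f"unknown collective: {op}")


def tp_volume(n: int, k: int) -> Fraction:
    # Plain 1D TP: g does an All-Reduce in backward, g_vec an All-Reduce in forward.
    return 2 * per_device_volume("all_reduce", n, k)


def sp_volume(n: int, k: int) -> Fraction:
    # TP + SP: g does All-Gather (fwd) + Reduce-Scatter (bwd),
    # g_vec does Reduce-Scatter (fwd) + All-Gather (bwd).
    return 2 * (per_device_volume("all_gather", n, k) + per_device_volume("reduce_scatter", n, k))
```

For n = 1024 and k = 4 both functions give 3072 elements per device. The extra backward All-Gather from item 4 adds another (k-1)/k · n on top of this, which is the traffic that enable_sequence_overlap overlaps with the gradient computation.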