
2D Tensor Parallelism

Author: Zhengda Bian, Yongbin Li

Prerequisite

Example Code

Related Paper

Introduction

1D tensor parallelism does not partition activations, which can also consume a large amount of memory for large-scale models. To evenly distribute both the computation and the memory load, an efficient 2D tensor parallelism algorithm was introduced, based on SUMMA (Scalable Universal Matrix Multiplication Algorithm).

Let's still take a linear layer $Y = XA$ as an example. Given $P = q \times q$ processors (a necessary condition), e.g. $q = 2$, we split both the input $X$ and the weight $A$ into

$$
\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \end{matrix} \right]
\text{~and~}
\left[\begin{matrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{matrix} \right].
$$

The calculation consists of $q$ steps. When $t=1$, $X_{i0}$ is broadcast within its row and $A_{0j}$ is broadcast within its column. So, we have

$$
\left[\begin{matrix} X_{00},A_{00} & X_{00},A_{01} \\ X_{10},A_{00} & X_{10},A_{01} \end{matrix} \right].
$$

Then we multiply $X_{i0}$ and $A_{0j}$ on each processor $(i, j)$ as

$$
\left[\begin{matrix} X_{00}A_{00} & X_{00}A_{01} \\ X_{10}A_{00} & X_{10}A_{01} \end{matrix} \right] \quad (1).
$$

Similarly, when $t=2$, $X_{i1}$ is broadcast within its row, $A_{1j}$ is broadcast within its column, and we multiply them as

$$
\left[\begin{matrix} X_{01}A_{10} & X_{01}A_{11} \\ X_{11}A_{10} & X_{11}A_{11} \end{matrix} \right] \quad (2).
$$

By adding $(1)$ and $(2)$ up, we have

$$
Y = XA = \left[\begin{matrix} X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \\ X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \end{matrix} \right].
$$
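
The two steps above generalize directly to any $q$. As a concrete illustration, here is a minimal single-process NumPy sketch that merely simulates the $q \times q$ processor grid (it is not the actual ColossalAI implementation and performs no real communication), accumulating the $q$ partial products per block and checking the result against a plain matrix multiplication:

```python
import numpy as np

q = 2                      # processor grid is q x q
n = 4                      # global matrix size (must be divisible by q)
b = n // q                 # block size held by each simulated processor

rng = np.random.default_rng(0)
X = rng.standard_normal((n, n))
A = rng.standard_normal((n, n))

# Split X and A into q x q blocks, mimicking how processor (i, j)
# would hold X_{ij} and A_{ij}.
X_blocks = [[X[i*b:(i+1)*b, j*b:(j+1)*b] for j in range(q)] for i in range(q)]
A_blocks = [[A[i*b:(i+1)*b, j*b:(j+1)*b] for j in range(q)] for i in range(q)]

# Each processor (i, j) accumulates its output block Y_{ij} over q SUMMA steps.
Y_blocks = [[np.zeros((b, b)) for _ in range(q)] for _ in range(q)]
for t in range(q):
    for i in range(q):
        for j in range(q):
            # Step t: X_{it} is broadcast along row i and A_{tj} along column j;
            # each processor multiplies them and adds the partial product.
            Y_blocks[i][j] += X_blocks[i][t] @ A_blocks[t][j]

# Reassemble the blocks and verify the result equals the full matmul.
Y = np.block(Y_blocks)
assert np.allclose(Y, X @ A)
print("2D SUMMA-style result matches X @ A")
```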

Efficiency

Given $P = q \times q$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm, for both the forward and backward passes of 2D tensor parallelism.

| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) |
| :---------: | :-----------------: | :------------------: | :-----------------------: | :---------------------: |
| $O(1/q^2)$  | $O(1/q^2)$          | $O(1/q^2)$           | $O(6(q-1)/q)$              | $O(6(q-1))$             |
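
As a quick worked instance of these entries (added here for illustration, using the same relative-cost notation as the table), plugging in $q=2$, i.e. $P=4$ processors, gives

$$
O\!\left(\frac{1}{q^2}\right) = O\!\left(\frac{1}{4}\right), \qquad
O\!\left(\frac{6(q-1)}{q}\right) = O(3), \qquad
O\!\left(6(q-1)\right) = O(6),
$$

i.e. each device holds roughly a quarter of the parameters and activations, while the communication terms remain bounded by constant factors.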

Usage

Currently the newest version of ColossalAI doesn't support 2D tensor parallelism, but this feature will be integrated into Shardformer in future releases. For more details about the ideas behind Shardformer and how to use it, please refer to Shardformer Doc.

For users of older versions of ColossalAI, please refer to ColossalAI-Examples - 2D Tensor Parallelism.
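
For reference, the sketch below shows roughly how 2D tensor parallelism was enabled in those older, config-based releases; this snippet is an illustration of the legacy `parallel` config convention rather than code taken from the linked example, so check the exact field names against the version you are running.

```python
# Legacy-style ColossalAI configuration sketch (pre-Shardformer API, assumed here
# for illustration): 2D tensor parallelism requires tensor.size to be a perfect
# square, e.g. 4 processes arranged as a 2 x 2 grid.
CONFIG = dict(
    parallel=dict(
        data=1,
        pipeline=1,
        tensor=dict(size=4, mode='2d'),
    )
)
```

With a config like this, layers provided by the old `colossalai.nn` module would be partitioned across the 2D processor grid according to the chosen mode.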