矩阵乘法是线性代数中最重要的运算之一。在机器学习中,矩阵乘法也是经常用到的运算,最常见于 MLP 线性层。
而在实际的模型训练和推理系统中,模型参数和中间激活的张量可能非常大,而 GPU 显存空间有限。因此,我们需要将张量切分为多个块,以在 GPU 上实现并行计算。而这和分块矩阵的乘法有着紧密的联系。
分块矩阵的定义
一个 m×n 维的矩阵 W 定义如下
Wm×n=w11w21⋮wm1w12w22⋮wm2⋯⋯⋱⋯w1nw2n⋮wmn现在将这个矩阵在行方向上划分为 X 个块,在列方向上划分为 Y 个块。这样,每个块的维度为 x×y,其中
⎩⎨⎧x=Xmy=Yn于是,分块后的矩阵可记为
WX×Y=W11W21⋮WX1W12W22⋮WX2⋯⋯⋱⋯W1YW2Y⋮WXY其中,每个块 Wij (1⩽i,j⩽X,Y) 可以表示为
Wij=w(i−1)x+1,(j−1)y+1w(i−1)x+2,(j−1)y+1⋮wix,(j−1)y+1w(i−1)x+1,(j−1)y+2w(i−1)x+2,(j−1)y+2⋮wix,(j−1)y+2⋯⋯⋱⋯w(i−1)x+1,jyw(i−1)x+2,jy⋮wix,jy向量与分块矩阵相乘
设一个 m 维的向量
λ=k1k2⋮km要使向量 λ 右乘分块矩阵 W,需要相应地将向量 λ 分成 X 个块,即
λ=λ1λ2⋮λX其中,每个块 λi (1⩽i⩽X) 可以表示为
λi=k(i−1)x+1k(i−1)x+2⋮kix因此,向量 λ 右乘分块矩阵 W 可以表示为
λ⊤W=[i=1∑Xλi⊤Wi1i=1∑Xλi⊤Wi2⋯i=1∑Xλi⊤Wiy]矩阵与分块矩阵相乘
设一个 d×m 维的矩阵
Ad×m=a11a21⋮ad1a12a22⋮ad2⋯⋯⋱⋯a1ma2m⋮adm仿照向量与分块矩阵相乘的方式,如果对矩阵 A 作 1D 切分,即
A=[A1A2⋯AX]其中,每个分块矩阵的维度是 d×x。那么,矩阵 A 右乘分块矩阵 W 可以表示为
AW=[i=1∑XAiWi1i=1∑XAiWi2⋯i=1∑XAiWiy]如果对矩阵 A 作 2D 切分,在行方向上切分成 T 个块,在列方向上切分成 X 个块。这样,每个块的维度就变成 t×x,其中
t=Td那么矩阵 A 可表示为
AT×X=A11A21⋮AT1A12A22⋮AT2⋯⋯⋱⋯A1XA2X⋮ATX这样,矩阵 A 右乘分块矩阵 W 可以表示为
AW=i=1∑XA1iWi1i=1∑XA2iWi1⋮i=1∑XATiWi1i=1∑XA1iWi2i=1∑XA2iWi2⋮i=1∑XATiWi2⋯⋯⋱⋯i=1∑XA1iWiyi=1∑XA2iWiy⋮i=1∑XATiWiy分块矩阵乘法的一般规律
与普通的矩阵乘法相比,分块矩阵乘法多了一个步骤:分块累加。用规范化的数学语言描述,即
Am×pBp×nAB=A11A21⋮AM1A12A22⋮AM2⋯⋯⋱⋯A1PA2P⋮AMP=B11B21⋮BP1B12B22⋮BP2⋯⋯⋱⋯B1NB2N⋮BPN=i=1∑PA1iBi1i=1∑PA2iBi1⋮i=1∑PAMiBi1i=1∑PA1iBi2i=1∑PA2iBi2⋮i=1∑PAMiBi2⋯⋯⋱⋯i=1∑PA1iBiNi=1∑PA2iBiN⋮i=1∑PAMiBiN其中,矩阵 A 被 2D 切分为 M×P 份,矩阵 B 被 2D 切分为 P×N 份。
而至于向量与分块矩阵的乘法,以及 1D 和 2D 切分的情况,都属于上述一般规律的特例。是否需要将分块累加,取决于是否对 p 维进行切分。
图解分块矩阵乘法
设想有 2 个矩阵都被切分成 2×2 的块,每个块在 4 张不同的 GPU 上进行矩阵乘法运算。则每个 GPU 上的计算分别如下面 4 张图所示:




分块矩阵乘法的意义
矩阵乘法是线性代数中最重要的运算之一。在机器学习中,矩阵乘法也是经常用到的运算,最常见于 MLP 线性层。
而在实际的模型训练和推理系统中,模型参数和中间激活的张量可能非常大,而 GPU 显存空间有限。因此,我们需要将张量切分为多个块,以在 GPU 上实现并行计算。
而分块矩阵乘法的意义在于,分块意味着一个完整的张量可以进行切分,在 GPU 上进行并行的矩阵乘法计算从理论上说是可行的。
不过,如果我们将矩阵在 p 维上切分,最终的计算结果需要沿着 p 维进行聚合。这样的操作会带来额外的通信开销。
此外,最终的结果从整体上看是多个小块矩阵的拼接,这意味着每个 GPU 可以存储相应的不同的结果。至于是否存在通信开销,取决于后面的算子是否需要完整的张量。