分块矩阵的乘法
矩阵乘法是线性代数中最重要的运算之一。在机器学习中,矩阵乘法也是经常用到的运算,最常见于 MLP 线性层。
而在实际的模型训练和推理系统中,模型参数和中间激活的张量可能非常大,而 GPU 显存空间有限。因此,我们需要将张量切分为多个块,以在 GPU 上实现并行计算。而这和分块矩阵的乘法有着紧密的联系。
分块矩阵的定义
一个 维的矩阵 定义如下
现在将这个矩阵在行方向上划分为 个块,在列方向上划分为 个块。这样,每个块的维度为 ,其中
于是,分块后的矩阵可记为
其中,每个块 可以表示为
向量与分块矩阵相乘
设一个 维的向量
要使向量 右乘分块矩阵 ,需要相应地将向量 分成 个块,即
其中,每个块 可以表示为
因此,向量 右乘分块矩阵 可以表示为
矩阵与分块矩阵相乘
设一个 维的矩阵
仿照向量与分块矩阵相乘的方式,如果对矩阵 作 1D 切分,即
其中,每个分块矩阵的维度是 。那么,矩阵 右乘分块矩阵 可以表示为
如果对矩阵 作 2D 切分,在行方向上切分成 个块,在列方向上切分成 个块。这样,每个块的维度就变成 ,其中
那么矩阵 可表示为
这样,矩阵 右乘分块矩阵 可以表示为
分块矩阵乘法的一般规律
与普通的矩阵乘法相比,分块矩阵乘法多了一个步骤:分块累加。用规范化的数学语言描述,即
其中,矩阵 被 2D 切分为 份,矩阵 被 2D 切分为 份。
而至于向量与分块矩阵的乘法,以及 1D 和 2D 切分的情况,都属于上述一般规律的特例。是否需要将分块累加,取决于是否对 维进行切分。
图解分块矩阵乘法
设想有 2 个矩阵都被切分成 的块,每个块在 4 张不同的 GPU 上进行矩阵乘法运算。则每个 GPU 上的计算分别如下面 4 张图所示:
分块矩阵乘法的意义
矩阵乘法是线性代数中最重要的运算之一。在机器学习中,矩阵乘法也是经常用到的运算,最常见于 MLP 线性层。
而在实际的模型训练和推理系统中,模型参数和中间激活的张量可能非常大,而 GPU 显存空间有限。因此,我们需要将张量切分为多个块,以在 GPU 上实现并行计算。
而分块矩阵乘法的意义在于,分块意味着一个完整的张量可以进行切分,在 GPU 上进行并行的矩阵乘法计算从理论上说是可行的。
不过,如果我们将矩阵在 维上切分,最终的计算结果需要沿着 维进行聚合。这样的操作会带来额外的通信开销。
此外,最终的结果从整体上看是多个小块矩阵的拼接,这意味着每个 GPU 可以存储相应的不同的结果。至于是否存在通信开销,取决于后面的算子是否需要完整的张量。