本节重点阐述矩阵乘法的计算量分析,这是后续 Transformer 模型的计算量分析的基础。
假设有两个向量 \(x\)、\(y\) 和两个矩阵 \(A\)、\(B\) 有如下的形状:
向量/矩阵 | 形状 |
---|---|
\(x\) | \([K]\) |
\(y\) | \([K]\) |
\(A\) | \([M, K]\) |
\(B\) | \([K, N]\) |
那么 向量/矩阵 之间的浮点计算量如下:
高维矩阵的计算量分析更加复杂一些,因为其中的维度分为三种情况:
为了更好地理解这三种维度,我们首先来看一个简单的例子:
假设我们有一个张量 \(A\),其形状为 \((\textcolor{blue}{B}, \textcolor{green}{M}, \textcolor{red}{K})\),以及另一个张量 \(B\),其形状为 \((\textcolor{blue}{B}, \textcolor{red}{K}, \textcolor{green}{N})\)。我们希望计算 \(C = A \times B\),在这种情况下:
结果张量 \(C\) 的形状为 \((\textcolor{blue}{B}, \textcolor{green}{M}, \textcolor{green}{N})\)。可以观察到,收缩维度被消去,批处理维度只保留一份,自由维度都被保留。矩阵乘法 \(A \times B\) 的计算量为 \(2\textcolor{blue}{B}\textcolor{green}{MN}\textcolor{red}{K}\)。
在推导大模型推理阶段的计算量之前,首先需要引入一系列符号用于表示推理过程中的关键概念:
符号 | 含义 |
---|---|
\(B\) | batch size |
\(L\) | number of layers |
\(T\) | sequence length (query) |
\(S\) | sequence length (key/value) |
\(V\) | vocab |
\(D\) | dimension of model (embedding size) |
\(F\) | MLP hidden dimension |
\(H\) | attention head dimension |
\(N\) | number of query heads |
\(K\) | number of key/value heads |
\(G\) | q heads per kv head |
在实践中,常常设置 \(NH=D\),但是严格上来说,两者的 dimension 可以不一致,所以需要区分开来。
在 MHA 中,query 的多头数量和 key/value 一致,都设置为 \(H\)。但是在 MQA 和 GQA 中,key/value 的头数量比 query 更少,上表中的 \(K\) 和 \(G\) 参数的引入也是为了方便对于这两种 attention 计算情况的论证:
\(G\) 的含义是一个 key/value 的头被几个 query 的头共用,所以 \(K \times G = N\)。
Embedding 本质是一个查表操作(look-up),不是 gemm,计算量相对小。
该阶段计算量很小,几乎可以忽略不计。
Attention 阶段核心包含以下几个数学公式:
Attention 计算可以分为三大部分:
其中 gemm 的计算量和访存量如下表所示:
operation | inference FLOPs | params | output shape |
---|---|---|---|
\(A[B,T,\textcolor{red}{D}] \cdot W_Q[\textcolor{red}{D},N,H]\) | \(2BTDNH\) | \(DNH\) | \(Q[B,T,D,H]\) |
\(A[B,T,\textcolor{red}{D}] \cdot W_K[\textcolor{red}{D},K,H]\) | \(2BTDKH\) | \(DKH\) | \(K[B,T,K,H]\) |
\(A[B,T,\textcolor{red}{D}] \cdot W_V[\textcolor{red}{D},K,H]\) | \(2BTDKH\) | \(DKH\) | \(V[B,T,K,H]\) |
\(A[B,T,\textcolor{red}{N,H}] \cdot W_O[\textcolor{red}{N,H},D]\) | \(2BTDNH\) | \(DNH\) | \(\text{Z}[B,T,D]\) |
其中 attention score 的计算量如下表所示:
operation | inference FLOPs | output shape |
---|---|---|
\(Q[\textcolor{blue}{B},T,\textcolor{blue}{K},G,\textcolor{red}{H}] \cdot K[\textcolor{blue}{B},S,\textcolor{blue}{K},\textcolor{red}{H}]\) | \(2BTSKGH=2BTSNH\) | \(\text{score}[B,T,S,K,G]=[B,T,S,N]\) |
\(\text{softmax}_{S}\ L[B,T,S,K,G]\) | \(O(BTSKG)=O(BTSN)\) | |
\(S[\textcolor{blue}{B},T,\textcolor{red}{S},\textcolor{blue}{K},G] \cdot V[\textcolor{blue}{B},\textcolor{red}{S},\textcolor{blue}{K},H]\) | \(2BTSKGH=2BTSNH\) | \(Y[B,T,K,G,H]=[B,T,N,H]\) |
根据以上推导,可以得到以下结论:
为了方便理解,我们考虑以下 \(B=1\)、\(N=1\) 的情况,计算过程为如下公式:
\[ \begin{aligned} Y&=\underbrace{\begin{bmatrix} \alpha_{11} & \alpha_{12} & \cdots & \alpha_{1s}\\ \alpha_{21} & \alpha_{22} & \cdots & \alpha_{2s}\\ \vdots & \vdots & \ddots & \vdots \\ \alpha_{t1} & \alpha_{t2} & \cdots & \alpha_{ts} \end{bmatrix}}_{\displaystyle \alpha=\operatorname{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)} \begin{bmatrix} v_1\\ v_2\\ \vdots\\ v_s \end{bmatrix} \\[4pt] &=\begin{bmatrix} \alpha_{11}v_1+\alpha_{12}v_2+\cdots+\alpha_{1s}v_s\\ \alpha_{21}v_1+\alpha_{22}v_2+\cdots+\alpha_{2s}v_s\\ \vdots\\ \alpha_{t1}v_1+\alpha_{t2}v_2+\cdots+\alpha_{ts}v_s \end{bmatrix}. \end{aligned} \]
其中 \(\alpha_{ij}\) 是当前 token 和先前每一个 token 的注意力得分,通过以下方式计算出来,注意 Softmax 是按照行作用的:
\[ s_{ij}=\frac{q_i\cdot k_j}{\sqrt{d_k}},\qquad \alpha_{ij}=\frac{e^{s_{ij}}}{\sum_{t=1}^{n} e^{s_{it}}} \]
\[ QK^{T} = \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_t \end{bmatrix} \begin{bmatrix} k_1^T & k_2^T & \cdots & k_s^T \end{bmatrix} = \begin{bmatrix} q_1 \cdot k_1^T & q_1 \cdot k_2^T & q_1 \cdot k_3^T & q_1 \cdot k_s^T \\ q_2 \cdot k_1^T & q_2 \cdot k_2^T & q_2 \cdot k_3^T & q_2 \cdot k_s^T \\ \vdots & \vdots & \ddots & \vdots \\ q_t \cdot k_1^T & q_t \cdot k_2^T & q_t \cdot k_3^T & q_t \cdot k_s^T \\ \end{bmatrix} \xrightarrow{\text{softmax 逐行归一化}} \begin{bmatrix} \alpha_{11} & \alpha_{12} & \cdots & \alpha_{1s}\\ \alpha_{21} & \alpha_{22} & \cdots & \alpha_{2s}\\ \vdots & \vdots & \ddots & \vdots \\ \alpha_{t1} & \alpha_{t2} & \cdots & \alpha_{ts} \end{bmatrix} \]
注意力得分矩阵的 shape 为 \([T, S]\),行长度就是 query length,列长度就是 kv length,每行就是一个 token 和之前 token 的注意力打分,还需要乘上对应的 \(v\) 向量。再乘以 \(V\) 的时候收缩的维度是在行上,所以 contracting dimension 是 \(S\)。
首先说一说 MLP,MLP 在当前 transformer 的模型中有两种常见实现方式,一种是 up/down,另一种是 in1/in2/out。
第一种 up/down 就是经典的 transformer 论文中提到的两层线性层,包含三个数学公式:
第二种方式是 in1/in2/out,两个 in 是并行的线性映射,一个负责主通道(值),一个负责门控(控制开关)。比传统 up/down 更灵活,计算量略多,但性能通常更好。
现在 transformer 架构通常使用第二种方式,几个核心的操作都是 gemm 操作,其计算量如下:
operation | inference FLOPs | params |
---|---|---|
\(A[B,T,\textcolor{red}{D}] \cdot W_{in1}[\textcolor{red}{D},F]\) | \(2BTDF\) | \(DF\) |
\(A[B,T,\textcolor{red}{D}] \cdot W_{in2}[\textcolor{red}{D},F]\) | \(2BTDF\) | \(DF\) |
\(\sigma(A_{in1})[B,T,F] * A_{in2}[B,T,F]\) | \(O(BTF)\) | |
\(A[B,T,\textcolor{red}{F}] \cdot W_{out}[\textcolor{red}{F},D]\) | \(2BTDF\) | \(DF\) |
如果是使用 MOE 的模型,主要包含以下几个数学公式: