- \([\mathbf{q}_{t,1};
\mathbf{q}_{t,2}; ...; \mathbf{q}_{t,n_h}] =
\mathbf{q}_t = W^{Q} \mathbf{h}_{t}\)
- \([\mathbf{k}_{t,1};
\mathbf{k}_{t,2}; ...; \mathbf{k}_{t,n_h}] =
\mathbf{k}_t = W^{K} \mathbf{h}_{t}\)
- \([\mathbf{v}_{t,1};
\mathbf{v}_{t,2}; ...; \mathbf{v}_{t,n_h}] =
\mathbf{v}_t = W^{V} \mathbf{h}_{t}\)
- \(\mathbf{o}_{t,i} =
\sum_{j=1}^{t}
\text{Softmax}_j\left(\frac{\mathbf{q}_{t,i}
\mathbf{k}_{j,i}^\top}{\sqrt{d_h}}\right)
\mathbf{v}_{j,i}\)
- \(i\) 表示第 \(i\) 个 head
- \(t\) 当时计算第
\(t\) 个 token 的
attention
- \(\mathbf{u}_t =
W^O[\mathbf{o}_{t,1}; \mathbf{o}_{t,2}; ...;
\mathbf{o}_{t,n_h}]\)
上述公式是 q_len = 1 的场景,并且考察了多头,head
的下标用变量 \(i\)
表示。
更简单的方式是去除多头的符号,因为 head,q_len
是可以并行的扩展。
- \(\mathbf{o}_{t} =
\sum_{j=1}^{t}
\text{Softmax}_j\left(\frac{\mathbf{q}_{t}
\mathbf{k}_{j}^\top}{\sqrt{d_h}}\right)
\mathbf{v}_{j}\)