Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30.
Transformer 是现代大模型最核心的基础架构之一。从 BERT、GPT 到如今的大语言模型,几乎都可以看作是在 Transformer 基础上的扩展和变体。
如果只用一句话概括它的核心思想,那就是:
- 不再像 RNN 那样按时间步顺序处理序列
- 而是用 self-attention 让每个 token 直接和序列中其他 token 建立联系
这种设计带来了两个非常重要的优点:
- 可以并行计算,训练效率更高
- 更容易建模长距离依赖关系
这篇文章会先把 Transformer 的整体架构讲清楚,再基于练习题中的推导,分析多头自注意力的参数量和计算复杂度。
一、为什么 Transformer 会出现
在 Transformer 之前,序列建模主要依赖两类方法:
- RNN / LSTM / GRU
- CNN 式序列模型
RNN 的问题在于,它天然是串行计算的。即使一个句子里后面的词和前面的词关系很远,信息也要一层层沿时间步传递。这会带来两个经典问题:
- 长距离依赖难学
- 训练和推理难以充分并行
CNN 虽然可以并行,但感受野需要通过堆叠层数逐渐扩大,对于全局依赖的建模仍然不够直接。
Transformer 的突破在于:让序列中每个位置都能直接“看见”其他所有位置,并根据相关性动态分配注意力权重。
因此,Transformer 的关键不是“记忆前一个状态”,而是“计算当前位置和其他位置之间的关系”。
二、整体架构
原始 Transformer 由两部分组成:
- Encoder
- Decoder
整体上是一个典型的 seq2seq 结构。
1. Encoder
Encoder 由多个相同的编码层堆叠而成。每一层通常包含两个子层:
- Multi-Head Self-Attention
- Position-wise Feed Forward Network
并且每个子层外面都有:
- Residual Connection
- Layer Normalization
因此一层 Encoder Block 的计算流程可以简化为:
$$
X \rightarrow \text{Multi-Head Attention} \rightarrow \text{Add & Norm} \rightarrow \text{FFN} \rightarrow \text{Add & Norm}
$$
2. Decoder
Decoder 同样由多个相同的解码层堆叠而成,但每层会多一个子层:
- Masked Multi-Head Self-Attention
- Cross-Attention
- Position-wise Feed Forward Network
其中:
- 第一层 masked self-attention 保证当前位置只能看到自己和前面的 token,不能偷看未来信息
- 第二层 cross-attention 让 decoder 去关注 encoder 输出的表示
因此 Decoder 比 Encoder 多了一步“从输入序列表示中取信息”的过程。
如果只讨论 GPT 这类自回归语言模型,那么通常只保留 Decoder 风格的 masked self-attention 结构;如果讨论 BERT,则更接近只用 Encoder。
三、输入表示
Transformer 的输入并不是单纯的 one-hot token,而是由两个部分相加得到:
- Token Embedding
- Positional Encoding
写成公式就是:
$$
X = E + P
$$
其中:
- $E$ 表示 token embedding
- $P$ 表示位置编码
这么做的原因很简单:attention 本身只关心“元素之间的关系”,它并不知道序列顺序。如果不加入位置信息,那么词序就会丢失。
位置编码为什么重要
例如下面两个句子:
- dog bites man
- man bites dog
如果模型完全不知道词的位置,这两个句子在 bag-of-words 意义上可能几乎一样,但语义完全不同。
因此 Transformer 必须额外注入顺序信息。
在原始论文中,位置编码使用的是正弦余弦函数:
$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i / d_{\text{model}}}}\right)
$$
$$
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i / d_{\text{model}}}}\right)
$$
它的直觉是:用不同频率的波形来编码位置,使得模型既能区分绝对位置,也能较容易推断相对位置信息。
四、自注意力机制
Transformer 的核心就是 attention,尤其是 self-attention。
1. Q、K、V 的含义
对于输入表示矩阵 $X \in \mathbb{R}^{n \times d_{\text{model}}}$,模型会分别通过三个线性映射得到:
$$
Q = XW^Q,\quad K = XW^K,\quad V = XW^V
$$
其中:
- $Q$ 是 Query
- $K$ 是 Key
- $V$ 是 Value
直观地说:
- Query 可以理解为“我现在想找什么信息”
- Key 可以理解为“我这里提供什么信息”
- Value 可以理解为“真正要取出的内容”
当某个位置的 Query 和另一个位置的 Key 越匹配时,它就应该更多地吸收对方的 Value。
2. Scaled Dot-Product Attention
单头注意力的公式是:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
这个公式可以分成三步理解。
第一步,计算相关性分数:
$$
S = QK^T
$$
这里的 $S_{ij}$ 表示第 $i$ 个 token 对第 $j$ 个 token 的关注程度。
第二步,进行缩放:
$$
\frac{QK^T}{\sqrt{d_k}}
$$
除以 $\sqrt{d_k}$ 的原因是,当向量维度变大时,点积的数值方差会变大,容易让 softmax 进入非常陡峭的区域,从而导致梯度不稳定。
第三步,做 softmax 并加权求和:
$$
A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)
$$
$$
O = AV
$$
其中 $A$ 是注意力权重矩阵,$O$ 是当前注意力层的输出。
3. Self-Attention 的含义
当 $Q,K,V$ 都来自同一个输入序列时,这种 attention 就叫 self-attention。
它的意义是:序列中的每个 token 都可以根据当前任务需要,从整个序列的其他位置动态收集信息。
例如在句子:
- The animal didn’t cross the street because it was tired.
中,词 it 的含义更接近 animal,而不是 street。self-attention 就允许模型直接给 animal 更高权重,而不需要像 RNN 那样跨很多步去传递状态。
五、多头注意力
单头注意力只能在一个表示子空间中建模关系,而多头注意力(Multi-Head Attention)希望让模型从多个不同角度同时观察序列。
设一共有 $h$ 个头,每个头的维度为:
$$
d_k = d_v = \frac{d_{\text{model}}}{h}
$$
第 $j$ 个头的计算为:
$$
\text{head}_j = \text{Attention}(Q_j, K_j, V_j)
$$
最后把所有头拼接起来:
$$
\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
$$
这样做的直觉是:
- 有的头可能关注语法关系
- 有的头可能关注实体对齐
- 有的头可能关注局部上下文
- 有的头可能关注长距离依赖
虽然每个头的维度更小,但多个头并行后,整体表达能力反而更强。
六、前馈网络、残差连接与 LayerNorm
Transformer 不只有 attention。每个 block 中还有一个位置独立的前馈网络:
$$
\text{FFN}(x) = W_2 \sigma(W_1 x + b_1) + b_2
$$
这个网络对每个位置分别作用,但参数共享。它的作用可以理解为:
- attention 负责“信息交换”
- FFN 负责“对每个位置的表示做非线性变换”
此外,每个子层后面都会配合残差连接和 LayerNorm:
$$
\text{LayerNorm}(x + \text{Sublayer}(x))
$$
残差连接有助于深层网络训练,LayerNorm 有助于稳定表示分布。
七、为什么 Transformer 训练得更快
Transformer 相比 RNN 的一个重大优势是并行性。
RNN 的状态更新是:
- 第 $t$ 步依赖第 $t-1$ 步
所以天然串行。
而 Transformer 在训练时可以一次性拿到整个序列,直接并行计算:
- 所有位置的 $Q,K,V$
- 所有位置之间的相关性矩阵
- 所有位置的注意力输出
因此在 GPU 上训练效率会高很多。
不过,这种全局两两交互也带来了新的问题:attention 的复杂度会随着序列长度平方增长。
这正是练习题里重点分析的部分。
八、多头自注意力的参数量与计算复杂度
下面基于练习题中的设定,分析一个 encoder 中单个 multi-head self-attention block 的参数量与 FLOPs。
设:
- 序列长度为 $n$
- 模型维度为 $d_{\text{model}}$
- 头数为 $h$
- 每个头的维度为 $d_k = d_v = d_{\text{model}} / h$
输入张量形状为:
$$
(1, n, d_{\text{model}})
$$
这里默认 batch size 为 1。
1. 参数量
一个多头注意力块中,主要有四个线性投影矩阵:
- $W^Q \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
- $W^K \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
- $W^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
- $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
因此:
$$
\mathrm{Params} = 4 d_{\text{model}}^2
$$
这个结果和头数 $h$ 无关。原因在于,多头拆分本质上只是 reshape,参数仍然集中存在于这四个总投影矩阵中。
2. FLOPs 推导
下面只看前向传播的主要计算量,并采用题目中的近似规则:
- 一个 $m \times k$ 与 $k \times p$ 的矩阵乘法,记为 $2mkp$ FLOPs
- softmax、缩放等逐元素操作近似记为低阶项
第一步:计算 Q、K、V 投影
输入 $X$ 的形状是 $(n, d_{\text{model}})$,乘上投影矩阵 $(d_{\text{model}}, d_{\text{model}})$。
每次投影代价为:
$$
2 n d_{\text{model}}^2
$$
三次投影合计:
$$
6 n d_{\text{model}}^2
$$
第二步:计算注意力分数 $QK^T$
对于每个头,$Q_{\text{head}}$ 的形状为 $(n, d_k)$,$K_{\text{head}}^T$ 的形状为 $(d_k, n)$。
每个头的代价为:
$$
2 n^2 d_k
$$
一共有 $h$ 个头,因此总代价为:
$$
h \cdot 2 n^2 d_k = h \cdot 2 n^2 \frac{d_{\text{model}}}{h} = 2 n^2 d_{\text{model}}
$$
第三步:缩放和 softmax
这一部分作用在 $(n,n,h)$ 的注意力分数张量上,近似记作:
$$
O(n^2 h)
$$
第四步:注意力权重与 V 相乘
对于每个头,注意力矩阵形状为 $(n,n)$,$V_{\text{head}}$ 形状为 $(n,d_k)$。
每个头的代价为:
$$
2 n^2 d_k
$$
所有头加起来:
$$
2 n^2 d_{\text{model}}
$$
第五步:输出投影
拼接所有头之后,输出形状回到 $(n,d_{\text{model}})$,再乘上 $W^O$,代价为:
$$
2 n d_{\text{model}}^2
$$
3. 总 FLOPs
把各部分加起来:
$$
6 n d_{\text{model}}^2 + 2 n^2 d_{\text{model}} + O(n^2 h) + 2 n^2 d_{\text{model}} + 2 n d_{\text{model}}^2
$$
整理得到:
$$
8 n d_{\text{model}}^2 + 4 n^2 d_{\text{model}} + O(n^2 h)
$$
如果忽略低阶项,那么主要复杂度就是:
$$
8 n d_{\text{model}}^2 + 4 n^2 d_{\text{model}}
$$
4. 哪一部分是瓶颈
从上式可以看出,多头自注意力中主要有两类大头开销:
- 线性投影相关的 $n d_{\text{model}}^2$
- 注意力相关的 $n^2 d_{\text{model}}$
因此瓶颈取决于:
- 是 $d_{\text{model}}$ 更大
- 还是序列长度 $n$ 更大
当 $d_{\text{model}} \gg n$ 时
线性投影更贵,因为:
$$
n d_{\text{model}}^2
$$
会显著大于:
$$
n^2 d_{\text{model}}
$$
例如中等长度序列、但模型宽度很大时,Q/K/V 和输出投影往往是主要成本。
当 $n \gg d_{\text{model}}$ 时
注意力分数计算和 attention 与 V 的矩阵乘法会成为瓶颈,因为它们都含有:
$$
n^2 d_{\text{model}}
$$
这也是为什么超长上下文会让 Transformer 的计算和显存开销迅速爆炸。
所以 Transformer 的核心效率问题,本质上就在于:
- 模型宽度决定了投影层成本
- 序列长度决定了 attention 的平方复杂度成本
九、Transformer 的优点与局限
优点
- 可以高效并行训练
- 更容易建模长距离依赖
- 架构统一,扩展性非常强
- self-attention 的交互方式非常灵活
局限
- attention 对长序列是平方复杂度
- 显存消耗大
- 对训练数据和算力依赖较强
- 原始位置编码方式并不一定最适合所有任务
也正因为这些问题,后续出现了大量改进方向,例如:
- 稀疏注意力
- 线性注意力
- 相对位置编码
- FlashAttention
- MoE 结构
但无论具体怎么改,很多工作依然是在 Transformer 主干上演化。
十、小结
Transformer 的核心不是“把 RNN 换成更深的网络”,而是重新定义了序列建模的基本方式:
- 用 self-attention 代替递归状态传递
- 用多头机制在不同子空间中并行建模关系
- 用残差连接、LayerNorm 和 FFN 组成稳定的堆叠结构
从这次练习题里的复杂度推导也可以看到,Transformer 的成功并不意味着它没有代价。它一方面通过并行计算大幅提升了训练效率,另一方面也因为 attention 的全局两两交互,引入了明显的平方复杂度瓶颈。
所以理解 Transformer,不能只停留在“attention 很强”这句话上,而要真正理解三件事:
- 它为什么能替代 RNN
- 它的每个模块到底在做什么
- 它的计算瓶颈到底来自哪里
这三点想明白之后,再去看 BERT、GPT、ViT 或各种高效注意力变体,就会容易很多。