0%

transformer架构

Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30.

Transformer 是现代大模型最核心的基础架构之一。从 BERT、GPT 到如今的大语言模型,几乎都可以看作是在 Transformer 基础上的扩展和变体。

如果只用一句话概括它的核心思想,那就是:

  • 不再像 RNN 那样按时间步顺序处理序列
  • 而是用 self-attention 让每个 token 直接和序列中其他 token 建立联系

这种设计带来了两个非常重要的优点:

  1. 可以并行计算,训练效率更高
  2. 更容易建模长距离依赖关系

这篇文章会先把 Transformer 的整体架构讲清楚,再基于练习题中的推导,分析多头自注意力的参数量和计算复杂度。

一、为什么 Transformer 会出现

在 Transformer 之前,序列建模主要依赖两类方法:

  • RNN / LSTM / GRU
  • CNN 式序列模型

RNN 的问题在于,它天然是串行计算的。即使一个句子里后面的词和前面的词关系很远,信息也要一层层沿时间步传递。这会带来两个经典问题:

  • 长距离依赖难学
  • 训练和推理难以充分并行

CNN 虽然可以并行,但感受野需要通过堆叠层数逐渐扩大,对于全局依赖的建模仍然不够直接。

Transformer 的突破在于:让序列中每个位置都能直接“看见”其他所有位置,并根据相关性动态分配注意力权重。

因此,Transformer 的关键不是“记忆前一个状态”,而是“计算当前位置和其他位置之间的关系”。

二、整体架构

原始 Transformer 由两部分组成:

  • Encoder
  • Decoder

整体上是一个典型的 seq2seq 结构。

1. Encoder

Encoder 由多个相同的编码层堆叠而成。每一层通常包含两个子层:

  1. Multi-Head Self-Attention
  2. Position-wise Feed Forward Network

并且每个子层外面都有:

  • Residual Connection
  • Layer Normalization

因此一层 Encoder Block 的计算流程可以简化为:

$$
X \rightarrow \text{Multi-Head Attention} \rightarrow \text{Add & Norm} \rightarrow \text{FFN} \rightarrow \text{Add & Norm}
$$

2. Decoder

Decoder 同样由多个相同的解码层堆叠而成,但每层会多一个子层:

  1. Masked Multi-Head Self-Attention
  2. Cross-Attention
  3. Position-wise Feed Forward Network

其中:

  • 第一层 masked self-attention 保证当前位置只能看到自己和前面的 token,不能偷看未来信息
  • 第二层 cross-attention 让 decoder 去关注 encoder 输出的表示

因此 Decoder 比 Encoder 多了一步“从输入序列表示中取信息”的过程。

如果只讨论 GPT 这类自回归语言模型,那么通常只保留 Decoder 风格的 masked self-attention 结构;如果讨论 BERT,则更接近只用 Encoder。

三、输入表示

Transformer 的输入并不是单纯的 one-hot token,而是由两个部分相加得到:

  1. Token Embedding
  2. Positional Encoding

写成公式就是:

$$
X = E + P
$$

其中:

  • $E$ 表示 token embedding
  • $P$ 表示位置编码

这么做的原因很简单:attention 本身只关心“元素之间的关系”,它并不知道序列顺序。如果不加入位置信息,那么词序就会丢失。

位置编码为什么重要

例如下面两个句子:

  • dog bites man
  • man bites dog

如果模型完全不知道词的位置,这两个句子在 bag-of-words 意义上可能几乎一样,但语义完全不同。

因此 Transformer 必须额外注入顺序信息。

在原始论文中,位置编码使用的是正弦余弦函数:

$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i / d_{\text{model}}}}\right)
$$

$$
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i / d_{\text{model}}}}\right)
$$

它的直觉是:用不同频率的波形来编码位置,使得模型既能区分绝对位置,也能较容易推断相对位置信息。

四、自注意力机制

Transformer 的核心就是 attention,尤其是 self-attention。

1. Q、K、V 的含义

对于输入表示矩阵 $X \in \mathbb{R}^{n \times d_{\text{model}}}$,模型会分别通过三个线性映射得到:

$$
Q = XW^Q,\quad K = XW^K,\quad V = XW^V
$$

其中:

  • $Q$ 是 Query
  • $K$ 是 Key
  • $V$ 是 Value

直观地说:

  • Query 可以理解为“我现在想找什么信息”
  • Key 可以理解为“我这里提供什么信息”
  • Value 可以理解为“真正要取出的内容”

当某个位置的 Query 和另一个位置的 Key 越匹配时,它就应该更多地吸收对方的 Value。

2. Scaled Dot-Product Attention

单头注意力的公式是:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

这个公式可以分成三步理解。

第一步,计算相关性分数:

$$
S = QK^T
$$

这里的 $S_{ij}$ 表示第 $i$ 个 token 对第 $j$ 个 token 的关注程度。

第二步,进行缩放:

$$
\frac{QK^T}{\sqrt{d_k}}
$$

除以 $\sqrt{d_k}$ 的原因是,当向量维度变大时,点积的数值方差会变大,容易让 softmax 进入非常陡峭的区域,从而导致梯度不稳定。

第三步,做 softmax 并加权求和:

$$
A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)
$$

$$
O = AV
$$

其中 $A$ 是注意力权重矩阵,$O$ 是当前注意力层的输出。

3. Self-Attention 的含义

当 $Q,K,V$ 都来自同一个输入序列时,这种 attention 就叫 self-attention。

它的意义是:序列中的每个 token 都可以根据当前任务需要,从整个序列的其他位置动态收集信息。

例如在句子:

  • The animal didn’t cross the street because it was tired.

中,词 it 的含义更接近 animal,而不是 street。self-attention 就允许模型直接给 animal 更高权重,而不需要像 RNN 那样跨很多步去传递状态。

五、多头注意力

单头注意力只能在一个表示子空间中建模关系,而多头注意力(Multi-Head Attention)希望让模型从多个不同角度同时观察序列。

设一共有 $h$ 个头,每个头的维度为:

$$
d_k = d_v = \frac{d_{\text{model}}}{h}
$$

第 $j$ 个头的计算为:

$$
\text{head}_j = \text{Attention}(Q_j, K_j, V_j)
$$

最后把所有头拼接起来:

$$
\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
$$

这样做的直觉是:

  • 有的头可能关注语法关系
  • 有的头可能关注实体对齐
  • 有的头可能关注局部上下文
  • 有的头可能关注长距离依赖

虽然每个头的维度更小,但多个头并行后,整体表达能力反而更强。

六、前馈网络、残差连接与 LayerNorm

Transformer 不只有 attention。每个 block 中还有一个位置独立的前馈网络:

$$
\text{FFN}(x) = W_2 \sigma(W_1 x + b_1) + b_2
$$

这个网络对每个位置分别作用,但参数共享。它的作用可以理解为:

  • attention 负责“信息交换”
  • FFN 负责“对每个位置的表示做非线性变换”

此外,每个子层后面都会配合残差连接和 LayerNorm:

$$
\text{LayerNorm}(x + \text{Sublayer}(x))
$$

残差连接有助于深层网络训练,LayerNorm 有助于稳定表示分布。

七、为什么 Transformer 训练得更快

Transformer 相比 RNN 的一个重大优势是并行性。

RNN 的状态更新是:

  • 第 $t$ 步依赖第 $t-1$ 步

所以天然串行。

而 Transformer 在训练时可以一次性拿到整个序列,直接并行计算:

  • 所有位置的 $Q,K,V$
  • 所有位置之间的相关性矩阵
  • 所有位置的注意力输出

因此在 GPU 上训练效率会高很多。

不过,这种全局两两交互也带来了新的问题:attention 的复杂度会随着序列长度平方增长。

这正是练习题里重点分析的部分。

八、多头自注意力的参数量与计算复杂度

下面基于练习题中的设定,分析一个 encoder 中单个 multi-head self-attention block 的参数量与 FLOPs。

设:

  • 序列长度为 $n$
  • 模型维度为 $d_{\text{model}}$
  • 头数为 $h$
  • 每个头的维度为 $d_k = d_v = d_{\text{model}} / h$

输入张量形状为:

$$
(1, n, d_{\text{model}})
$$

这里默认 batch size 为 1。

1. 参数量

一个多头注意力块中,主要有四个线性投影矩阵:

  • $W^Q \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
  • $W^K \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
  • $W^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
  • $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$

因此:

$$
\mathrm{Params} = 4 d_{\text{model}}^2
$$

这个结果和头数 $h$ 无关。原因在于,多头拆分本质上只是 reshape,参数仍然集中存在于这四个总投影矩阵中。

2. FLOPs 推导

下面只看前向传播的主要计算量,并采用题目中的近似规则:

  • 一个 $m \times k$ 与 $k \times p$ 的矩阵乘法,记为 $2mkp$ FLOPs
  • softmax、缩放等逐元素操作近似记为低阶项

第一步:计算 Q、K、V 投影

输入 $X$ 的形状是 $(n, d_{\text{model}})$,乘上投影矩阵 $(d_{\text{model}}, d_{\text{model}})$。

每次投影代价为:

$$
2 n d_{\text{model}}^2
$$

三次投影合计:

$$
6 n d_{\text{model}}^2
$$

第二步:计算注意力分数 $QK^T$

对于每个头,$Q_{\text{head}}$ 的形状为 $(n, d_k)$,$K_{\text{head}}^T$ 的形状为 $(d_k, n)$。

每个头的代价为:

$$
2 n^2 d_k
$$

一共有 $h$ 个头,因此总代价为:

$$
h \cdot 2 n^2 d_k = h \cdot 2 n^2 \frac{d_{\text{model}}}{h} = 2 n^2 d_{\text{model}}
$$

第三步:缩放和 softmax

这一部分作用在 $(n,n,h)$ 的注意力分数张量上,近似记作:

$$
O(n^2 h)
$$

第四步:注意力权重与 V 相乘

对于每个头,注意力矩阵形状为 $(n,n)$,$V_{\text{head}}$ 形状为 $(n,d_k)$。

每个头的代价为:

$$
2 n^2 d_k
$$

所有头加起来:

$$
2 n^2 d_{\text{model}}
$$

第五步:输出投影

拼接所有头之后,输出形状回到 $(n,d_{\text{model}})$,再乘上 $W^O$,代价为:

$$
2 n d_{\text{model}}^2
$$

3. 总 FLOPs

把各部分加起来:

$$
6 n d_{\text{model}}^2 + 2 n^2 d_{\text{model}} + O(n^2 h) + 2 n^2 d_{\text{model}} + 2 n d_{\text{model}}^2
$$

整理得到:

$$
8 n d_{\text{model}}^2 + 4 n^2 d_{\text{model}} + O(n^2 h)
$$

如果忽略低阶项,那么主要复杂度就是:

$$
8 n d_{\text{model}}^2 + 4 n^2 d_{\text{model}}
$$

4. 哪一部分是瓶颈

从上式可以看出,多头自注意力中主要有两类大头开销:

  1. 线性投影相关的 $n d_{\text{model}}^2$
  2. 注意力相关的 $n^2 d_{\text{model}}$

因此瓶颈取决于:

  • 是 $d_{\text{model}}$ 更大
  • 还是序列长度 $n$ 更大

当 $d_{\text{model}} \gg n$ 时

线性投影更贵,因为:

$$
n d_{\text{model}}^2
$$

会显著大于:

$$
n^2 d_{\text{model}}
$$

例如中等长度序列、但模型宽度很大时,Q/K/V 和输出投影往往是主要成本。

当 $n \gg d_{\text{model}}$ 时

注意力分数计算和 attention 与 V 的矩阵乘法会成为瓶颈,因为它们都含有:

$$
n^2 d_{\text{model}}
$$

这也是为什么超长上下文会让 Transformer 的计算和显存开销迅速爆炸。

所以 Transformer 的核心效率问题,本质上就在于:

  • 模型宽度决定了投影层成本
  • 序列长度决定了 attention 的平方复杂度成本

九、Transformer 的优点与局限

优点

  • 可以高效并行训练
  • 更容易建模长距离依赖
  • 架构统一,扩展性非常强
  • self-attention 的交互方式非常灵活

局限

  • attention 对长序列是平方复杂度
  • 显存消耗大
  • 对训练数据和算力依赖较强
  • 原始位置编码方式并不一定最适合所有任务

也正因为这些问题,后续出现了大量改进方向,例如:

  • 稀疏注意力
  • 线性注意力
  • 相对位置编码
  • FlashAttention
  • MoE 结构

但无论具体怎么改,很多工作依然是在 Transformer 主干上演化。

十、小结

Transformer 的核心不是“把 RNN 换成更深的网络”,而是重新定义了序列建模的基本方式:

  • 用 self-attention 代替递归状态传递
  • 用多头机制在不同子空间中并行建模关系
  • 用残差连接、LayerNorm 和 FFN 组成稳定的堆叠结构

从这次练习题里的复杂度推导也可以看到,Transformer 的成功并不意味着它没有代价。它一方面通过并行计算大幅提升了训练效率,另一方面也因为 attention 的全局两两交互,引入了明显的平方复杂度瓶颈。

所以理解 Transformer,不能只停留在“attention 很强”这句话上,而要真正理解三件事:

  1. 它为什么能替代 RNN
  2. 它的每个模块到底在做什么
  3. 它的计算瓶颈到底来自哪里

这三点想明白之后,再去看 BERT、GPT、ViT 或各种高效注意力变体,就会容易很多。