Skip to content

Transformer

RNN 的主要问题:

  • 梯度消失/爆炸:长距离依赖难以学习
  • 顺序计算:无法并行处理序列
  • 信息瓶颈:最后时刻隐藏状态需承载全部信息

Transformer 的改进:

  • 并行计算:同时处理整个序列
  • 自注意力机制:直接建立任意位置间的联系
  • 位置编码:显式注入位置信息

结构

Transformer

图:Transformer 单元

Transformer

图:Transformer 的详细结构

  • 输入
    • 编码器输入
    • 解码器输入
  • 输出
    • 线性层
    • Softmax 层
  • 编码器
    • 由 N 个编码器层堆叠而成
    • 每个编码器层由两个子层连接结构组成
    • 第一个子层连接结构包括一个多头自注意力子层规范化层以及一个残差连接
    • 第二个子层连接结构包括一个前馈全连接子层规范化层以及一个残差连接
  • 解码器
    • 由 N 个解码器层堆叠而成
    • 每个解码器层由三个子层连接结构组成
    • 第一个子层连接结构包括一个带掩码的-多头自注意力子层规范化层以及一个残差连接
    • 第二个子层连接结构包括一个多头注意力子层(编码器到解码器)和规范化层以及一个残差连接
    • 第三个子层连接结构包括一个前馈全连接子层规范化层以及一个残差连接

核心组件

自注意力机制(Self-Attention)

在传统的神经网络处理序列时,模型只能一步步按顺序处理,难以捕捉长距离依赖关系。自注意力机制就是为了让序列中的每个元素都能直接与序列中所有其他元素进行交互,无论它们直接的距离多远。

定义:

符号维度含义
Xn×d输入矩阵(n=序列长度,d=特征维度)
Qn×dkQuery 矩阵(查询向量)
Kn×dkKey 矩阵(键向量)
Vn×dvValue 矩阵(值向量)
WQ,WK,WVd×dk/dv可学习参数矩阵
Attention(Q,K,V)=Softmax(QKTdk)V

TIP

  • Q 表示当前需要关注的信息或问题,用于确定输入序列中哪些部分与当前任务相关
  • K 用于匹配查询,通过计算相似度判断输入序列中哪些元素与查询匹配
  • V 存储实际信息

推导过程

将输入转换为 Query、Key、Value:

Q=XWQ,K=XWK,V=XWV

计算注意力分数:

Scores=QKTdk

生成注意力权重矩阵:

A=Softmax(Scores)

得到最终注意力输出:

Output=AV

带掩码自注意力层(Masked Multi-head attention)

编码时,对于 t 时刻的预测,我们知道 x1,x2,,xt,xt+1,,xT 全部的信息。

解码时,对于 t 时刻的预测,我们仅知道 x1,x2,,xt1 的信息。看不到后续的信息,因此需要将后续的信息遮掩起来。

Attention(Q,K,V)=Softmax(QKTMdk)V

多头注意力(Multi-Head Attention)

Transformer

MultiHead(Q,K,V)=Concat(head1,,headh)WO

其中:

headi=Attention(QWiQ,KWiK,VWiV)

位置编码(Positional Encoding)

RNNLSTM 等顺序算法不同,Transformer 没有内置机制来捕获句子中单词的相对位置,所以在 Transformerencoderdecoder 的输入层中,使用了 Positional Encoding,使得最终的输入满足:

input=input_embedding+positional_encoding

原始正弦编码公式:

PE(pos,2i)=sin(pos100002i/d)PE(pos,2i+1)=cos(pos100002i/d)

前馈网络(Feed Forward Network)

包括两个线性变换+ReLU 激活:

FFN(x)=ReLU(xW1+b1)W2+b2

计算复杂度

当输入批次大小为 b,序列长度为 N,词向量的维度(隐藏层的维度)为 d 时,l 层 transformer 的计算复杂度:

Self-Attention 层

FLOPs(Self-Attention)=8bNd2+4bN2d
  1. 计算 QKV

输入输出

[b,N,d]×[d,d][b,N,d]

计算量为:

FLOPs=3QKVb(Ndd[N,d]×[d,d]乘法+Ndd[N,d]×[d,d]加法)=6bNd2

TIP

矩阵加法运算考虑偏差 bias计算量就是 Ndd,如果不考虑偏差就是 Nd(d1),但这个 1 一般忽略不计。

  1. 计算 QKT

输入输出

[b,h,N,dk]×[b,h,dk,N][b,h,N,N]

h 为注意力头数,dk 为每个头的维度,hdk=d

FLOPs=bh(N2dk+N2dk)=2bN2d
  1. Softmax 与加权求和

Softmax 计算量较小,通常忽略。

输入输出

[b,h,N,N]×[b,h,N,dk][b,h,N,dk]
FLOPs=bh(NdkN+NdkN)=2bN2d
  1. 输出投影

线性变换将结果映射回 n 维:

输入输出

[b,N,d]×[d,d][b,N,d]
FLOPs=2bNd2

MLP 层

FLOPs(MLP)=16bNd2
  1. 线性层(扩展层)

输入输出

[b,N,d]×[d,4d][b,N,4d]
FLOPs=8bNd2
  1. 线性层(压缩层)

输入输出

[b,N,4d]×[4d,d][b,N,d]
FLOPs=8bNd2

logits

Logits 层是将最终的 Transformer 隐藏层输出(维度 d)映射到词表大小 V,即一个线性投影:

输入输出

[b,N,d]×[d,V][b,N,V]
FLOPs(logits)=2bNdV

总的计算复杂度

FLOPs(Transformer)=l(24bNd2+4bN2d)+2bNdV

空间复杂度

大模型在训练过程中通常采用混合精度训练,中间激活值一般是 float16 或者 bfloat16 数据类型的。在分析中间激活的显存占用时,假设中间激活值是以 float16 或 bfloat16 数据格式来保存的,每个元素占了 2 个 bytes,dropout 操作的 mask 矩阵,每个元素只占 1 个 bytes。需要保存的中间激活占用显存大小计算如下:

Self-Attention 层

  1. QKV 共享一个输入 X,则显存占用为 2bNd
  2. 对于 QKT,两个张量形状都是 [b,N,d],显存占用为 4bNd
  3. 对于 Softmax,函数输入 QKT 形状为 [b,h,N,N],显存占用为 2bN2h
  4. 计算完 Softmax,会进行 dropout,需要保存一个 mask 矩阵,其形状与 QKT 相同,显存占用为 bN2h
  5. 计算 ScoresV,二者占用显存大小为 2bN2h+2bNd
  6. 计算输出映射和一个 dropout 操作,二者占用显存大小为 2bNd+bNd

综上,Self-Attention 层的显存占用为 11bNd+5bN2a

MLP 层

  1. 第一个线性层的输入占用显存 2bNd
  2. 激活函数的输入占用显存 8bNd
  3. 第二个线性层的输入占用显存 8bNd
  4. 最后的 dropout 操作需要保存的 mask 矩阵占用显存 bNd

综上,MLP 层的显存占用为 19bNd

LN

Self-Attention 层和 MLP 层分别对应了一个 LN,其输入占用显存为 2bNd+2bNd

总的空间复杂度

l(34bNd+5bN2h)

问题

Transformer 的计算复杂度为:l(24bNd2+4bN2d)+2bNdV,需要保存的中间激活占用显存大小为:l(34bNd+5bN2h),即 Transformer 模型的计算量和储存复杂度随着序列长度 N 呈二次方增长。

可以注意到,4bN2d5bN2h 均产生于 Self-Attention 层。