在 Transformer 的自注意力机制中,K(键矩阵)、Q(查询矩阵)、V(值矩阵) 的形状推导与输入 embedding 的维度直接相关,具体流程和 N2 复杂度的来源如下:
基础输入与矩阵形状定义设输入序列长度为
N,每个 token 的 embedding 维度为
dmodel,则输入 embedding 矩阵
X 的形状为
[N,dmodel]。
自注意力会用 3 个可学习的参数矩阵
WQ,WK,WV(形状均为
[dmodel,dk],其中
dk 为键 / 查询的维度)对输入进行线性变换,得到:
- 查询矩阵 Q=XWQ,形状 [N,dk]
- 键矩阵 K=XWK,形状 [N,dk]
- 值矩阵 V=XWV,形状 [N,dv](通常 dv=dk)
N2 复杂度的核心来源:注意力分数计算自注意力的核心是计算
查询与所有键的相似度,即注意力分数矩阵
A:
A=QK⊤其中
K⊤ 是
K 的转置,形状为
[dk,N]。
根据矩阵乘法规则,
[N,dk]×[dk,N] 的运算结果是一个
[N,N] 的方阵 —— 这个方阵的每个元素
Aij 对应第
i 个 token 的查询与第
j 个 token 的键的相似度,
矩阵共有 N×N=N2 个元素,这就是自注意力机制时间复杂度为
O(N2) 的直接原因。
后续对
A 做 softmax 归一化,再与
V 相乘得到最终注意力输出,都不会改变这一核心复杂度。
embedding 到 K/Q/V 的完整流程输入 token 经过词嵌入层得到
[N,dmodel] 的 embedding → 通过 3 个独立的线性层(无激活函数)分别映射为
Q,K,V → 计算
QK⊤ 得到
N2 规模的注意力分数 → 分数归一化后与
V 加权求和,输出长度为
N 的注意力特征序列。
我可以帮你整理自注意力与多头注意力的复杂度对比表,清晰展示多头机制对计算量的影响,需要吗?