你把那个 K, Q, V 矩阵的形状,以及这个 embedding 它这个怎么体现到它是长度的 N 平方,这个矩阵的这个流程再给我讲一下。下车时请勿
自注意力中 K、Q、V 矩阵及复杂度
导出时间:2026/5/21 13:48:02
来源地址:https://www.doubao.com/chat/33386257946145282
消息数量:2
在 Transformer 的自注意力机制中,K(键矩阵)、Q(查询矩阵)、V(值矩阵) 的形状推导与输入 embedding 的维度直接相关,具体流程和 N2 复杂度的来源如下:
- 基础输入与矩阵形状定义设输入序列长度为 N,每个 token 的 embedding 维度为 dmodel,则输入 embedding 矩阵 X 的形状为 [N,dmodel]。自注意力会用 3 个可学习的参数矩阵 WQ,WK,WV(形状均为 [dmodel,dk],其中 dk 为键 / 查询的维度)对输入进行线性变换,得到:
- 查询矩阵 Q=XWQ,形状 [N,dk]
- 键矩阵 K=XWK,形状 [N,dk]
- 值矩阵 V=XWV,形状 [N,dv](通常 dv=dk)
- N2 复杂度的核心来源:注意力分数计算自注意力的核心是计算查询与所有键的相似度,即注意力分数矩阵 A:A=QK⊤其中 K⊤ 是 K 的转置,形状为 [dk,N]。根据矩阵乘法规则,[N,dk]×[dk,N] 的运算结果是一个 [N,N] 的方阵 —— 这个方阵的每个元素 Aij 对应第 i 个 token 的查询与第 j 个 token 的键的相似度,矩阵共有 N×N=N2 个元素,这就是自注意力机制时间复杂度为 O(N2) 的直接原因。后续对 A 做 softmax 归一化,再与 V 相乘得到最终注意力输出,都不会改变这一核心复杂度。
- embedding 到 K/Q/V 的完整流程输入 token 经过词嵌入层得到 [N,dmodel] 的 embedding → 通过 3 个独立的线性层(无激活函数)分别映射为 Q,K,V → 计算 QK⊤ 得到 N2 规模的注意力分数 → 分数归一化后与 V 加权求和,输出长度为 N 的注意力特征序列。
我可以帮你整理自注意力与多头注意力的复杂度对比表,清晰展示多头机制对计算量的影响,需要吗?