从零构建LLM 第3章:注意力机制
1. 注意力机制是什么
在传统 NLP 模型中,每个词的表示是固定的 —— “bank” 在 “river bank” 和 “bank account” 中的向量完全相同。自注意力 (Self-Attention) 改变了这一点:它让每个 token 根据当前上下文,动态地聚合序列中其他 token 的信息。
一句话总结:注意力机制就是让每个 token 问一句「在这个上下文里,我应该关注谁?」
本章从最简单的形式出发,经过四次迭代升级到 GPT 的完整注意力机制:
- 简单自注意力 — 直接用输入向量的点积计算相关性,无可训练参数
- 可训练自注意力 — 引入 Query/Key/Value 权重矩阵,让模型学习「关注什么」
- 因果注意力 — 添加掩码,防止偷看未来 token(自回归生成的前提)
- 多头注意力 — 多个注意力头并行计算,捕捉不同类型的语义关系
2. 简单自注意力:用点积衡量相关性
自注意力的核心思想出奇地简单:用点积衡量两个向量的相似度。两个向量越「对齐」,点积越大,说明它们越相关。整个计算只需要三步:
- 计算注意力分数 (Attention Scores):将查询 token 与所有 token 做点积
- 归一化为权重 (Attention Weights):用 softmax 将分数转为概率分布(和为 1)
- 加权求和 (Context Vector):用权重对所有 token 做加权平均,得到融合上下文的新表示
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x_1)
[0.55, 0.87, 0.66], # journey (x_2)
[0.57, 0.85, 0.64], # starts (x_3)
[0.22, 0.58, 0.33], # with (x_4)
[0.77, 0.25, 0.10], # one (x_5)
[0.05, 0.80, 0.55]] # step (x_6)
)
# 以 x_2 ("journey") 为查询,计算与每个 token 的点积
query = inputs[1]
attn_scores = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
attn_scores[i] = torch.dot(query, x_i)
# softmax 归一化 → 注意力权重(概率分布,和为1)
attn_weights = torch.softmax(attn_scores, dim=0)
# tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
# 加权求和 → 上下文向量(融合了所有 token 信息的新表示)
context_vec = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
context_vec += attn_weights[i] * x_i
# tensor([0.4419, 0.6515, 0.5683])
ℹ️ 直觉理解:注意力权重本质上是一个概率分布。在上面的例子中,x_2 (“journey”) 对 x_2 自身和 x_3 (“starts”) 给出了最高权重 (0.24, 0.23) — 这些是与它语义最相关的 token。上下文向量则是所有 token 按这个分布做的「加权平均」,包含了整个序列的信息。
3. 可训练的自注意力:Query/Key/Value
简单自注意力有个问题:输入向量同时充当「提问者」和「被查询对象」,角色没有区分。书中引入了三个可训练的权重矩阵来分离这些角色:
- Query (查询):「我在找什么信息?」— 定义当前 token 的需求
- Key (键):「我能提供什么信息?」— 描述每个 token 可以匹配的特征
- Value (值):「我实际包含什么内容?」— 一旦匹配成功,实际传递的信息
计算公式:Attention(Q, K, V) = softmax(Q · KT / √dk) · V
其中除以 √dk(缩放因子)是一个关键细节:当嵌入维度 dk 较大时,点积的数值会变得很大,导致 softmax 输出趋近于 one-hot 分布(梯度极小)。缩放可以保持梯度在合理范围内。
import torch.nn as nn
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x) # 所有 token 的 Key
queries = self.W_query(x) # 所有 token 的 Query
values = self.W_value(x) # 所有 token 的 Value
# 缩放点积注意力: Q · K^T / sqrt(d_k)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
# 用注意力权重对 Value 加权求和
context_vec = attn_weights @ values
return context_vec
⚠️ 为什么需要缩放?假设 Q 和 K 的每个元素都是均值为 0、方差为 1 的随机变量。它们的点积的方差 = dk(维度越高,方差越大)。除以 √dk 可以把方差重新拉回 1,防止 softmax 进入梯度饱和区。这就是 “Scaled Dot-Product Attention” 中 “Scaled” 的由来。
4. 因果注意力:不能偷看未来
GPT 是自回归 (Autoregressive) 模型 — 生成文本时,第 t 个 token 只能看到前 t−1 个 token。但标准自注意力让每个 token 都能看到整个序列(包括「未来」的 token)。
因果掩码 (Causal Mask) 解决了这个问题:用一个上三角矩阵将「未来位置」的注意力分数设为 −∞,softmax 后这些位置的权重自然变为 0。
# 假设已经计算了 6x6 的注意力分数矩阵
context_length = attn_scores.shape[0]
# 创建上三角掩码:对角线以上全为 1
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# tensor([[0, 1, 1, 1, 1, 1],
# [0, 0, 1, 1, 1, 1],
# [0, 0, 0, 1, 1, 1],
# [0, 0, 0, 0, 1, 1],
# [0, 0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0, 0]])
# 将掩码位置设为 -inf → softmax 后变为 0
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
# 效果:每个 token 只能看到自己和「过去」的 token
# row 0: [1.00, 0.00, 0.00, 0.00, 0.00, 0.00] ← 只看自己
# row 1: [0.52, 0.48, 0.00, 0.00, 0.00, 0.00] ← 看前2个
# row 2: [0.34, 0.33, 0.33, 0.00, 0.00, 0.00] ← 看前3个 ...
书中还在注意力权重上加了 Dropout,随机将部分权重置零。这起到正则化的作用,防止模型过度依赖某些固定的注意力模式。在训练时开启,推理时关闭。
5. 多头注意力:让模型并行思考
如果只有一个注意力头,模型只能学习一种「关注模式」。但语言理解需要同时捕捉多种关系 — 语法依赖、语义相似、指代消解等。多头注意力 (Multi-Head Attention) 的做法是:把 Q/K/V 的维度拆分成多个头,每个头独立计算注意力,最后将结果拼接并投影。
| 方面 | 单头注意力 | 多头注意力 (GPT-2, 12头) |
|---|---|---|
| Q/K/V 维度 | 768 | 64 per head × 12 heads = 768 |
| 关注模式 | 1 种 | 12 种(语法 / 语义 / 位置 …) |
| 参数量 | 3 × 768 × 768 | 相同(拆分,不是复制) |
| 输出 | 单个 context vector | 12 个拼接 + 线性投影 |
关键实现细节:这不是复制 12 份注意力模块(那样参数量会 ×12),而是将现有的 d_out 维度拆分为 num_heads × head_dim。比如 GPT-2 用 768 维和 12 个头,每个头只处理 64 维 — 总参数量不变。
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out 必须能被 num_heads 整除"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # 每个头的维度
# Q/K/V 权重矩阵(所有头共享一组大矩阵)
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # 输出投影
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
def forward(self, x):
b, num_tokens, d_in = x.shape
# 计算 Q/K/V 并拆分为多个头
# [b, seq, d_out] → [b, num_heads, seq, head_dim]
keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
# 每个头独立计算因果注意力
attn_scores = queries @ keys.transpose(2, 3)
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# 合并所有头: [b, num_heads, seq, head_dim] → [b, seq, d_out]
context_vec = (attn_weights @ values).transpose(1, 2).contiguous()
context_vec = context_vec.view(b, num_tokens, self.d_out)
return self.out_proj(context_vec) # 最终线性投影
6. 为什么这很重要
注意力机制不仅是 Transformer 的核心组件,也是理解现代 LLM 优化技术的基础:
- KV Cache:因果注意力的单向性质意味着已计算的 Key/Value 可以缓存复用。这就是 LLM 推理时「首 token 慢、后续 token 快」(prefill vs decode) 的原因。
- Flash Attention:标准注意力需要存储完整的 n×n 注意力矩阵。Flash Attention 通过分块计算 (tiling) 和在线 softmax 避免了 O(n²) 的显存开销,是目前所有主流推理框架的标配。
- 注意力可视化:分析 attention weights 可以帮助理解模型在「关注」什么。虽然不能完全解释推理过程,但在调试和 interpretability 研究中很有价值。
- 长上下文挑战:自注意力的计算复杂度是 O(n²),这是限制上下文窗口的根本原因。从 GPT-2 的 1024 token 到现代模型的 128K+ token,背后是大量的注意力效率优化。
- 位置信息:注意力本身是置换不变的 (permutation invariant) — 它不知道 token 的顺序。这就是上一章位置嵌入至关重要的原因:没有它,模型分不清「猫吃鱼」和「鱼吃猫」。
ℹ️ 总结 — 第3章注意力进化路线:简单点积注意力(无参数)→ 可训练 Q/K/V + 缩放 → 因果掩码 + Dropout → 多头并行计算 + 输出投影。这四步构成了 GPT 中 Multi-Head Causal Attention 的完整实现,也是下一章构建完整 GPT 模型的核心积木。
本文是 Build a Large Language Model From Scratch (Sebastian Raschka) 的学习笔记。所有配图版权归原作者所有。代码基于原书示例,有简化和中文注释。