从零构建LLM 第6章:文本分类微调
1. 从生成到分类:为什么 GPT 能做分类
GPT 是一个生成模型,它的输出是每个位置的下一个 token 概率。但每一层 TransformerBlock 都在做同一件事:提取和融合上下文信息。经过 12 层处理后,最后一个 token 的隐状态已经「看过」了整个输入序列 — 它可以作为整个文本的压缩表示。
要把 GPT 变成分类器,只需要两步改动:
- 替换输出头:原来的输出层映射到词汇表大小 (50257),换成映射到类别数 (2)
- 只取最后一个 token 的输出:用它做分类决策,因为它的隐状态汇聚了最完整的上下文
ℹ️ 为什么是最后一个 token?在因果注意力(第3章)中,只有最后一个 token 能通过注意力机制「看到」前面所有 token。它的隐状态等效于对整个序列做了 pooling。相比之下,BERT 使用 [CLS] token(第一个位置)做分类,因为 BERT 用的是双向注意力。
2. 数据准备:垃圾邮件数据集
书中使用 SMS 垃圾邮件数据集做演示。关键步骤包括:
- 数据平衡:原始数据中 “not spam” 远多于 “spam”,需要下采样到 1:1 比例
- 编码:用 tiktoken 将文本转为 token ID 序列
- 填充 (Padding):不同文本长度不同,需要填充到统一长度才能组成 batch
- 划分:训练集 / 验证集 / 测试集
class SpamDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None,
pad_token_id=50256):
self.data = pd.read_csv(csv_file)
# 编码所有文本
self.encoded_texts = [
tokenizer.encode(text) for text in self.data["Text"]
]
if max_length is None:
self.max_length = self._longest_encoded_length()
else:
self.max_length = max_length
# 截断过长的序列
self.encoded_texts = [
encoded[:max_length] for encoded in self.encoded_texts
]
# 用 <|endoftext|> (50256) 填充到统一长度
self.encoded_texts = [
encoded + [pad_token_id] * (self.max_length - len(encoded))
for encoded in self.encoded_texts
]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["Label"]
return (
torch.tensor(encoded, dtype=torch.long),
torch.tensor(label, dtype=torch.long)
)
def __len__(self):
return len(self.data)
3. 模型改造:换头与冻结层
模型改造的核心原则是:尽量保留预训练学到的知识,只调整最少的参数。书中的做法是:
- 加载 OpenAI 的 GPT-2 预训练权重
- 替换输出头(
out_head)为二分类线性层 - 冻结所有参数,只解冻最后一个 TransformerBlock +
final_norm+ 新分类头
# 1. 加载预训练 GPT-2 权重
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
# 2. 替换输出头:词汇表预测 → 二分类
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"], # 768
out_features=num_classes # 2 (spam / not spam)
)
# 3. 冻结所有层的参数
for param in model.parameters():
param.requires_grad = False
# 4. 只解冻需要微调的部分
for param in model.trf_blocks[-1].parameters(): # 最后一个 TransformerBlock
param.requires_grad = True
for param in model.final_norm.parameters(): # 最终 LayerNorm
param.requires_grad = True
for param in model.out_head.parameters(): # 新分类头
param.requires_grad = True
⚠️ 冻结多少层?冻结越多层,训练越快、过拟合风险越低,但表达能力也越受限。书中实验发现:只解冻最后一个 TransformerBlock 就能达到 95%+ 准确率。对于数据量小的任务,这种「保守微调」策略通常效果最好。数据量大时可以考虑解冻更多层。
4. 训练与推理
训练流程与预训练类似,但有一个关键区别:只取最后一个 token 的 logits 做分类。
def train_classifier_simple(model, train_loader, val_loader,
optimizer, device, num_epochs):
for epoch in range(num_epochs):
model.train()
for input_batch, target_batch in train_loader:
optimizer.zero_grad()
input_batch = input_batch.to(device)
target_batch = target_batch.to(device)
logits = model(input_batch)[:, -1, :] # 只取最后一个 token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
loss.backward()
optimizer.step()
# 评估准确率
train_acc = calc_accuracy_loader(train_loader, model, device)
val_acc = calc_accuracy_loader(val_loader, model, device)
print(f"Ep {epoch+1}: Train {train_acc:.2%}, Val {val_acc:.2%}")
# --- 推理函数 ---
def classify_review(text, model, tokenizer, device, max_length=None):
model.eval()
input_ids = tokenizer.encode(text)[:max_length]
# 填充到统一长度
input_ids += [50256] * (max_length - len(input_ids))
input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)
with torch.no_grad():
logits = model(input_tensor)[:, -1, :]
predicted_label = torch.argmax(logits, dim=-1).item()
return "spam" if predicted_label == 1 else "not spam"
# 示例
print(classify_review("You are a winner! Claim your prize now", ...))
# → "spam"
书中的实验结果:
| 数据集 | 准确率 |
|---|---|
| 训练集 | 97.21% |
| 验证集 | 97.32% |
| 测试集 | 95.67% |
仅微调最后一个 TransformerBlock 和分类头,就能在垃圾邮件分类上达到 95%+ 的测试准确率。这充分体现了预训练表示的迁移能力— GPT 在预训练中学到的语言知识,可以直接复用到下游分类任务。
5. 为什么这很重要
本章展示了 LLM 最实用的应用模式之一 — 分类微调。以下是几个关键启示:
- 迁移学习的威力:只需要微调最后几层和一个新的分类头,就能在小数据集上获得优秀的分类性能。预训练阶段学到的语言理解能力(词义、语法、上下文)可以直接迁移。
- 数据效率:整个 SMS 数据集只有几千条样本。如果从头训练一个分类器,这个数据量远远不够。但基于预训练 GPT 微调,几千条样本就足以达到 95%+ 准确率。
- 层冻结策略:这是一个实用技巧 — 数据量小时冻结更多层(防止过拟合),数据量大时解冻更多层(增加表达能力)。这个权衡在实际项目中非常常见。
- 分类 vs 生成的选择:对于需要明确判断的任务(情感分析、spam 检测、意图识别),分类微调比让 LLM 用文本方式回答更可靠、更高效。输出是一个概率分布而非自由文本,不存在幻觉问题。
- 现代实践:在实际工作中,分类微调通常用 BERT/RoBERTa 等编码器模型(因为双向注意力天然更适合理解任务)。但本章的方法论 — 换头、冻结、微调 — 同样适用于所有模型架构。
ℹ️ 总结 — 第6章分类微调流程:预训练 GPT → 替换输出头 (vocab_size → num_classes) → 冻结大部分层 → 只解冻最后 TransformerBlock + LayerNorm + 分类头 → 用标注数据训练 → 取最后 token 的 logits 做分类决策。下一章将探索 LLM 的另一种微调模式 — 指令微调,让模型学会遵循指令。
本文是 Build a Large Language Model From Scratch (Sebastian Raschka) 的学习笔记。所有配图版权归原作者所有。代码基于原书示例,有简化和中文注释。