推理模型从零构建 第8章:推理蒸馏 (Reasoning Distillation)
1. 为什么需要推理蒸馏
前两章用 GRPO 强化学习训练了推理模型,效果显著(15.2% → 47.4%)。但 RL 训练存在几个实际问题:
- 训练不稳定:第7章花了大量篇幅处理退化、KL 崩溃、reward hacking 等问题。
- 计算成本高:每步需要采样多个 rollout → 前向传播 → 反向传播,比普通 SFT 贵 5–10 倍。
- 超参敏感:learning rate、clip_eps、num_rollouts、temperature 等都需要精心调节。
推理蒸馏 (Reasoning Distillation) 提供了一条更简单的替代路径:让一个大而强的 teacher model(如 DeepSeek-R1 671B)生成高质量的推理数据(包含 <think> 推理过程和最终答案),然后用这些数据对小而快的 student model做标准的 supervised fine-tuning。
ℹ️ DeepSeek-R1 的成功案例:DeepSeek 用 R1 (671B) 生成了 80 万条推理数据,蒸馏到 Qwen 和 Llama 的小模型上。蒸馏后的 DeepSeek-R1-Distill-Qwen-7B 在多个数学 benchmark 上超过了直接用 RL 训练的同等大小模型。蒸馏的核心优势是:小模型可以 “借用” 大模型的推理能力,而不需要自己经历漫长且不稳定的 RL 训练。
2. 蒸馏流程概览
推理蒸馏的流程分为两大阶段:
- 数据生成阶段:用 teacher model 对训练题目生成带推理过程的回答
- SFT 训练阶段:用生成的数据对 student model 做 supervised fine-tuning
数据格式示意:
{
"problem": "A rectangular band formation...",
"gtruth_answer": "98",
"message_thinking": "I need to find the largest number of...",
"message_content": "The function is continuous..."
}
# 组合为完整的训练样本:
complete_answer = f"{data['message_thinking']} \n\n{data['message_content']}"
# → "I need to find the largest number of... \n\nThe function is continuous..."
每条数据包含四个字段:原始题目、标准答案(用于质量过滤)、teacher 的推理过程、teacher 的最终回答。训练时,模型学习的目标是:给定 problem,生成完整的 <think>...</think> + answer。
3. 蒸馏数据生成
书中提供了两种数据生成方式:本地 Ollama 和云端 OpenRouter API。
3.1 本地生成(Ollama)
适合可以在本地运行的中小模型(Qwen3 4B、DeepSeek R1 8B/32B):
# 安装 Ollama: https://ollama.com
# 下载模型: ollama run deepseek-r1:8b
uv run generate_with_ollama.py \
--math_json math_train_sample.json \
--dataset_size 5 \
--model deepseek-r1:8b \
--max_new_tokens 8192 \
--out_file sample_ollama_outputs.json
# 输出:
# 5/5 | MATH-500: 5/5 | ETA: 00s
# Total time: 3.2 min
# Wrote 5 rows to: sample_ollama_outputs.json
关键参数:
--model:选择 teacher 模型。内存 ≥ 30 GB 用deepseek-r1:8b;≥ 60 GB 用deepseek-r1:32b(质量更高)。--max_new_tokens 8192:推理过程可能很长,设置太小会截断回答。可根据内存调整到 2048。--resume:支持断点续传,适合大规模数据生成。
3.2 云端生成(OpenRouter)
适合本地无法运行的大模型(DeepSeek R1 671B、Kimi K2.5 1T):
OPENROUTER_API_KEY="YOUR_KEY" uv run generate_with_openrouter.py \
--math_json math_train_sample.json \
--dataset_size 1000 \
--model deepseek/deepseek-r1 \
--num_processes 50 \
--out_file distillation_data.json
# 成本估算(DeepSeek R1 671B):
# 1000 题 × 平均 1524 output tokens ≈ $3.82
# 输入极便宜:11 tokens/题 × $0.70/M = 可忽略
# 用 --num_processes 50 并行可将 100 小时缩到 ~2 小时
ℹ️ 成本对比:用 DeepSeek R1 671B 生成 1000 题的蒸馏数据只需 ~$3.82。即使生成全部 12,000 题也不到 $50。相比 RL 训练所需的 GPU 小时数(H100 上数十小时),蒸馏的数据生成成本极低。真正的计算成本在后续的 SFT 训练阶段,但 SFT 比 RL 简单且稳定得多。
4. Supervised Fine-Tuning
有了蒸馏数据后,训练过程就是标准的 SFT(本章的完整 SFT 代码正在编写中)。核心步骤:
- 数据过滤:用第3章的
grade_answer验证 teacher 回答是否正确,只保留正确的样本。 - 构造训练样本:将 problem 作为 input,
<think>...</think> + boxed answer作为 target。 - 标准 cross-entropy 训练:和预训练相同的 loss 函数,模型学习在给定 problem 条件下生成完整的推理+回答。
- 可选:RL 后训练:蒸馏后的模型可以再用 GRPO 做少量 RL 训练进一步提升。
import json
# 加载蒸馏数据
with open("distillation_data.json") as f:
data = json.load(f)
# 过滤:只保留 teacher 回答正确的样本
filtered = []
for item in data:
teacher_answer = extract_final_candidate(item["message_content"])
if grade_answer(teacher_answer, item["gtruth_answer"]):
filtered.append({
"input": render_prompt(item["problem"]),
"target": f"{item['message_thinking']} \n\n"
f"{item['message_content']}"
})
print(f"过滤后保留 {len(filtered)}/{len(data)} 条数据")
# 接下来:标准 SFT 训练循环
# for batch in dataloader:
# loss = cross_entropy(model(batch["input"]), batch["target"])
# loss.backward()
# optimizer.step()
5. 蒸馏 vs RL 的权衡
| 维度 | RL (GRPO) | 蒸馏 (Distillation) |
|---|---|---|
| 训练复杂度 | 高(采样+reward+advantage+更新) | 低(标准 SFT) |
| 训练稳定性 | 差(需要大量调参和 tricks) | 好(cross-entropy 训练很稳定) |
| 数据依赖 | 只需题目+答案验证器 | 需要 teacher model 生成数据 |
| 推理能力上限 | 可以超越 teacher(自主探索) | 受限于 teacher 的能力 |
| 计算成本 | 高(每步多次采样+前向+反向) | 数据生成便宜 + SFT 训练便宜 |
| 适用场景 | 训练最强推理模型 | 快速部署中小推理模型 |
实际应用中,两种方法常常组合使用:
- 先蒸馏:用 teacher 数据做 SFT,给 student model 一个好的 “起点”
- 再 RL:在蒸馏基础上做少量 GRPO 训练,进一步提升并可能超越 teacher
这正是 DeepSeek-R1 论文中描述的完整流程。蒸馏提供了稳定的基础能力,RL 提供了自主探索和突破的可能性。
6. 关键收获与展望
- 蒸馏是 “站在巨人肩膀上” 的方法:小模型无需经历漫长的 RL 训练,直接学习大模型的推理模式。DeepSeek 证明了蒸馏后的 7B 模型可以超过直接 RL 训练的同等大小模型。
- 数据质量决定蒸馏效果:teacher 的回答质量是关键。建议:(1) 使用最强的 teacher model;(2) 用验证器过滤错误回答;(3) 保留完整的
<think>推理过程,而不仅仅是最终答案。 - 蒸馏数据生成成本极低:通过 API 调用大模型生成 1000 条数据只需 ~$4,是 RL 训练成本的零头。这使得快速迭代和实验成为可能。
- 蒸馏 + RL 是当前的最佳实践:先蒸馏建立基础能力,再用少量 RL 训练突破 teacher 的上限。纯蒸馏的天花板是 teacher 的能力,而 RL 可以让 student 自主发现更好的推理路径。
- 本章完整代码仍在编写中:目前书中已提供蒸馏数据生成的完整工具链(Ollama + OpenRouter)。SFT 训练的完整实现敬请期待后续更新。
ℹ️ 全书回顾:从第2章的 Qwen3 文本生成,到第3章的评估体系,到第4-5章的 inference-time scaling(采样、投票、self-refinement),再到第6-7章的 GRPO 强化学习,最后到本章的推理蒸馏 —— 我们走完了构建推理模型的完整路径。核心启示:推理能力不是预训练阶段的副产品,而是需要通过专门的后训练阶段(RL 或蒸馏)来激发和强化的。
本文是 Build a Reasoning Model (From Scratch) (Sebastian Raschka) 的学习笔记。本章内容部分基于已发布的补充材料和 DeepSeek-R1 论文,完整 SFT 代码待原书更新后补充。所有配图版权归原作者所有。