推理模型从零构建 第4章:Inference-Time Scaling
1. 核心思想:推理时扩展
传统提升模型性能的方式是 train-time scaling(更大的模型、更多的数据、更长的训练)。Inference-time scaling 是另一条路径:模型权重不变,但在推理时花更多计算来得到更好的结果。
本章介绍的三种方法都属于 inference-time scaling:
- Chain-of-thought prompting:在 prompt 末尾加 “Explain step by step”,引导模型展示推理过程,准确率从 15% → 41%
- Temperature + Top-p + Self-consistency:多次随机采样,投票选最终答案,准确率最高达 52%
- Self-refinement(第5章内容):让模型评估并改进自己的回答
2. Temperature Scaling
到目前为止,我们用的是 greedy decoding(取 argmax),输出是确定性的。要做多次采样投票,首先需要引入随机性。Temperature 参数控制输出分布的 “锐度”:
实现非常简单 — 在 softmax 之前,把 logits 除以 temperature:
from reasoning_from_scratch.qwen3 import KVCache
def scale_logits_by_temperature(logits, temperature):
"""logits / T → 控制 softmax 输出的分布锐度"""
if temperature <= 0:
raise ValueError("Temperature must be positive")
return logits / temperature
@torch.inference_mode()
def generate_text_temp_stream_cache(model, token_ids, max_new_tokens,
eos_token_id=None, temperature=0.):
model.eval()
cache = KVCache(n_layers=model.cfg["n_layers"])
model.reset_kv_cache()
out = model(token_ids, cache=cache)[:, -1]
for _ in range(max_new_tokens):
if temperature is None or temperature == 0.0:
# T=0: 退化为 greedy decoding(argmax)
next_token = torch.argmax(out, dim=-1, keepdim=True)
else:
# T>0: 按概率随机采样
logits = scale_logits_by_temperature(out, temperature)
probas = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probas.cpu(), num_samples=1)
next_token = next_token.to(token_ids.device)
if eos_token_id is not None and torch.all(next_token == eos_token_id):
break
yield next_token
out = model(next_token, cache=cache)[:, -1]
书中展示了一个直观实验:对 “The capital of Germany is” 这个 prompt,用 T=0.35 采样 1000 次,Berlin 出现 435 次,但也出现了 “______”(模型试图出填空题)。T=5.0 则完全失控,连 “mistress” 都出来了。
ℹ️ Temperature 直觉:T=1.0 时就是标准 softmax;T→0 退化为 argmax(greedy);T→∞ 退化为均匀分布。数学推理通常用 T=0.5~1.0 — 既要有多样性(生成不同推理路径),又不能太随机(保证推理连贯性)。
3. Top-p (Nucleus) Sampling
Temperature 只控制分布的锐度,但即使 T 较低,仍然有小概率采样到完全离谱的 token。Top-p sampling(也叫 nucleus sampling)解决这个问题:只保留累积概率达到阈值 p 的最高概率 token 子集,把其余的概率设为 0。
def top_p_filter(probas, top_p):
"""只保留累积概率达到 top_p 的最高概率 token 子集"""
if top_p is None or top_p >= 1.0:
return probas # 不过滤
# 1. 按概率降序排列
sorted_probas, sorted_idx = torch.sort(probas, dim=1, descending=True)
# 2. 计算累积和
cumprobas = torch.cumsum(sorted_probas, dim=1)
# 3. 保留前缀累积概率 < top_p 的 token(含跨越阈值的那个)
prefix = cumprobas - sorted_probas
keep = prefix < top_p
keep[:, 0] = True # 至少保留一个 token
# 4. 截断:把低概率 token 设为 0
kept_sorted = torch.where(keep, sorted_probas, torch.zeros_like(sorted_probas))
filtered = torch.zeros_like(probas).scatter(1, sorted_idx, kept_sorted)
# 5. 重新归一化
return filtered / torch.sum(filtered, dim=1, keepdim=True).clamp_min(1e-12)
# 效果对比:T=0.35 采样 1000 次
# 无 top-p: Berlin=435, ______=209, ____=169, __=158, ...(7种token)
# top_p=0.8: Berlin=534, ______=249, ____=217 (只剩3种token)
Top-p 的妙处在于它是自适应的:当模型很确信时(某个 token 概率 > p),只保留那一个;当模型不确定时,会保留更多候选。这比 top-k(固定保留 k 个)更合理。
4. Self-Consistency 多数投票
有了随机采样能力,就可以实现 Self-Consistency(Wang et al., 2022)—— 核心思想非常直觉:多次采样 → 提取每次的答案 → 投票选出现次数最多的答案。
from collections import Counter
def self_consistency_vote(model, tokenizer, prompt, device,
num_samples=10, temperature=0.8,
top_p=0.9, max_new_tokens=2048, seed=None):
short_answers = []
# 1. 采样 num_samples 条不同的推理链
for i in range(num_samples):
if seed is not None:
torch.manual_seed(seed + i + 1)
answer = generate_text_stream_concat_flex(
model, tokenizer, prompt, device,
max_new_tokens=max_new_tokens,
generate_func=generate_text_top_p_stream_cache,
temperature=temperature, top_p=top_p,
)
# 2. 从每条推理链中提取最终答案
short = extract_final_candidate(answer, fallback="number_then_full")
short_answers.append(short)
# 3. 多数投票(plurality vote)
counts = Counter(short_answers)
final_answer = counts.most_common(1)[0][0]
return final_answer
# 示例:5 次采样 → [83, 22, 54, 83, 61] → 83 获胜(出现2次)
# 加上 CoT prompt 后:[83, 83, 83, 83, 83] → 83(5票一致)
⚠️ CoT + Self-Consistency 的协同效应:单独用 temperature+top-p 采样,5 次中只有 2 次得到正确答案 83。但加上 "Explain step by step" 后,5 次全部正确。CoT prompting 让模型 "写出推理过程",大幅提高了单次采样的正确率,self-consistency 进一步通过投票消除偶发错误。两者结合效果远超单独使用。
5. 完整实验结果
书中在 NVIDIA DGX Spark 上对 MATH-500 完整 500 题做了系统实验:
| 方法 | 模型 | 准确率 | 耗时 |
|---|---|---|---|
| Baseline (greedy) | Base | 15.2% | 10 min |
| Baseline (greedy) | Reasoning | 48.2% | 182 min |
| CoT prompting | Base | 40.6% | 85 min |
| Top-p sampling alone | Base | 17.8% | 31 min |
| Top-p + Self-consistency (n=3) | Base | 29.6% | 98 min |
| Top-p + Self-consistency (n=10) | Base | 31.6% | 300 min |
| CoT + Top-p | Base | 33.4% | 129 min |
| CoT + Top-p + Self-consistency (n=5) | Base | 48.0% | 453 min |
| CoT + Top-p + Self-consistency (n=10) | Base | 52.0% | 863 min |
| CoT + Top-p + Self-consistency (n=3) | Reasoning | 55.2% | 544 min |
几个关键发现:
- CoT 是最高效的单一改进:仅加一句 "Explain step by step",准确率从 15.2% → 40.6%(2.7x),时间只增加 8.5x(因为输出更长)。
- Self-consistency 收益递减:n=3 到 n=10,准确率从 42.2% 到 52.0%,但时间从 212 min 增到 863 min。
- Base + CoT + SC(n=10) 超过了 Reasoning model baseline:52.0% vs 48.2%,说明 inference-time scaling 可以弥补模型训练的不足。
- 但代价是巨大的计算开销:863 min vs 10 min,86 倍的推理成本。这就是为什么我们还需要第6-7章的 RL 训练。
6. 关键收获
- Inference-time scaling 是模型训练的有效补充:不需要改动模型权重,就能在推理时大幅提升性能。实际应用中(如 Claude 4),多次并行采样 + 内部评分模型是标准做法。
- Temperature 和 Top-p 是采样的两个基本旋钮:Temperature 控制分布锐度,Top-p 截断低概率尾部。两者联合使用才能在多样性和连贯性之间取得平衡。
- Self-consistency 原理简单但有效:「如果模型从不同的推理路径得到相同答案,那个答案更可能正确」— 这个直觉在实验中得到了验证。
- CoT prompting 是最具性价比的技术:一句话的改动,效果堪比 10 次采样投票。它的本质是让模型 "写出中间步骤",模型在预测下一个 token 时能 attend 到自己之前的推理过程。
- 计算换精度有上限:Self-consistency 的收益随 n 增大而递减。要进一步提升,需要 train-time scaling(第6-7章的 RL 训练)或更聪明的 inference-time 方法(第5章的 self-refinement)。
ℹ️ 下一章预告:Self-consistency 是 "生成多个答案,选最好的"。第5章将进一步实现 self-refinement — 让模型生成一个答案后,用 log-probability 评估它的质量,然后迭代改进,直到得到更好的结果。
本文是 Build a Reasoning Model (From Scratch) (Sebastian Raschka) 的学习笔记。所有配图版权归原作者所有。代码基于原书示例,有简化和中文注释。