推理模型从零构建 第4章:Inference-Time Scaling

📅 2026-03-02📖 ~9 min readReasoningSamplingSelf-Consistency
本文是 Build a Reasoning Model (From Scratch)(Sebastian Raschka 著)第4章的学习笔记。核心思想:不修改模型权重,仅在推理时投入更多计算来提升准确率。本章实现了三种关键技术:Temperature Scaling(控制输出多样性)、Top-p Sampling(平衡多样性与连贯性)、Self-Consistency 多数投票(采样多个推理链,投票选出最终答案)。三者结合将 base 模型的 MATH-500 准确率从 15% 提升到 52%。
← Ch3: 评估推理模型Ch5: 自我优化 →

1. 核心思想:推理时扩展

传统提升模型性能的方式是 train-time scaling(更大的模型、更多的数据、更长的训练)。Inference-time scaling 是另一条路径:模型权重不变,但在推理时花更多计算来得到更好的结果

Inference-time scaling concept

投入更多推理时计算(如多次采样)可以持续提升准确率 (图源: Reasoning from Scratch)

本章介绍的三种方法都属于 inference-time scaling:

Three inference-time methods

三种 inference-time 方法:CoT prompting、多采样投票、self-refinement (图源: Reasoning from Scratch)
  • Chain-of-thought prompting:在 prompt 末尾加 “Explain step by step”,引导模型展示推理过程,准确率从 15% → 41%
  • Temperature + Top-p + Self-consistency:多次随机采样,投票选最终答案,准确率最高达 52%
  • Self-refinement(第5章内容):让模型评估并改进自己的回答

2. Temperature Scaling

到目前为止,我们用的是 greedy decoding(取 argmax),输出是确定性的。要做多次采样投票,首先需要引入随机性。Temperature 参数控制输出分布的 “锐度”:

Temperature scaling effect

T < 1 让分布更尖锐(更确定),T > 1 让分布更平坦(更随机)(图源: Reasoning from Scratch)

实现非常简单 — 在 softmax 之前,把 logits 除以 temperature:

Python — Temperature Sampling
from reasoning_from_scratch.qwen3 import KVCache

def scale_logits_by_temperature(logits, temperature):
    """logits / T → 控制 softmax 输出的分布锐度"""
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    return logits / temperature

@torch.inference_mode()
def generate_text_temp_stream_cache(model, token_ids, max_new_tokens,
                                     eos_token_id=None, temperature=0.):
    model.eval()
    cache = KVCache(n_layers=model.cfg["n_layers"])
    model.reset_kv_cache()

    out = model(token_ids, cache=cache)[:, -1]
    for _ in range(max_new_tokens):
        if temperature is None or temperature == 0.0:
            # T=0: 退化为 greedy decoding(argmax)
            next_token = torch.argmax(out, dim=-1, keepdim=True)
        else:
            # T>0: 按概率随机采样
            logits = scale_logits_by_temperature(out, temperature)
            probas = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probas.cpu(), num_samples=1)
            next_token = next_token.to(token_ids.device)

        if eos_token_id is not None and torch.all(next_token == eos_token_id):
            break
        yield next_token
        out = model(next_token, cache=cache)[:, -1]

书中展示了一个直观实验:对 “The capital of Germany is” 这个 prompt,用 T=0.35 采样 1000 次,Berlin 出现 435 次,但也出现了 “______”(模型试图出填空题)。T=5.0 则完全失控,连 “mistress” 都出来了。

ℹ️ Temperature 直觉:T=1.0 时就是标准 softmax;T→0 退化为 argmax(greedy);T→∞ 退化为均匀分布。数学推理通常用 T=0.5~1.0 — 既要有多样性(生成不同推理路径),又不能太随机(保证推理连贯性)。


3. Top-p (Nucleus) Sampling

Temperature 只控制分布的锐度,但即使 T 较低,仍然有小概率采样到完全离谱的 token。Top-p sampling(也叫 nucleus sampling)解决这个问题:只保留累积概率达到阈值 p 的最高概率 token 子集,把其余的概率设为 0。

Top-p sampling concept

Top-p sampling:按概率排序后截断,只保留累积概率 ≤ p 的 token (图源: Reasoning from Scratch)
Python — Top-p Filter
def top_p_filter(probas, top_p):
    """只保留累积概率达到 top_p 的最高概率 token 子集"""
    if top_p is None or top_p >= 1.0:
        return probas  # 不过滤

    # 1. 按概率降序排列
    sorted_probas, sorted_idx = torch.sort(probas, dim=1, descending=True)

    # 2. 计算累积和
    cumprobas = torch.cumsum(sorted_probas, dim=1)

    # 3. 保留前缀累积概率 < top_p 的 token(含跨越阈值的那个)
    prefix = cumprobas - sorted_probas
    keep = prefix < top_p
    keep[:, 0] = True  # 至少保留一个 token

    # 4. 截断:把低概率 token 设为 0
    kept_sorted = torch.where(keep, sorted_probas, torch.zeros_like(sorted_probas))
    filtered = torch.zeros_like(probas).scatter(1, sorted_idx, kept_sorted)

    # 5. 重新归一化
    return filtered / torch.sum(filtered, dim=1, keepdim=True).clamp_min(1e-12)


# 效果对比:T=0.35 采样 1000 次
# 无 top-p:  Berlin=435, ______=209, ____=169, __=158, ...(7种token)
# top_p=0.8: Berlin=534, ______=249, ____=217         (只剩3种token)

Top-p 的妙处在于它是自适应的:当模型很确信时(某个 token 概率 > p),只保留那一个;当模型不确定时,会保留更多候选。这比 top-k(固定保留 k 个)更合理。


4. Self-Consistency 多数投票

有了随机采样能力,就可以实现 Self-Consistency(Wang et al., 2022)—— 核心思想非常直觉:多次采样 → 提取每次的答案 → 投票选出现次数最多的答案

Self-consistency voting

Self-consistency:多条推理链 → 提取答案 → 多数投票 (图源: Reasoning from Scratch)
Python — Self-Consistency Voting
from collections import Counter

def self_consistency_vote(model, tokenizer, prompt, device,
                           num_samples=10, temperature=0.8,
                           top_p=0.9, max_new_tokens=2048, seed=None):
    short_answers = []

    # 1. 采样 num_samples 条不同的推理链
    for i in range(num_samples):
        if seed is not None:
            torch.manual_seed(seed + i + 1)

        answer = generate_text_stream_concat_flex(
            model, tokenizer, prompt, device,
            max_new_tokens=max_new_tokens,
            generate_func=generate_text_top_p_stream_cache,
            temperature=temperature, top_p=top_p,
        )
        # 2. 从每条推理链中提取最终答案
        short = extract_final_candidate(answer, fallback="number_then_full")
        short_answers.append(short)

    # 3. 多数投票(plurality vote)
    counts = Counter(short_answers)
    final_answer = counts.most_common(1)[0][0]
    return final_answer


# 示例:5 次采样 → [83, 22, 54, 83, 61] → 83 获胜(出现2次)
# 加上 CoT prompt 后:[83, 83, 83, 83, 83] → 83(5票一致)

⚠️ CoT + Self-Consistency 的协同效应:单独用 temperature+top-p 采样,5 次中只有 2 次得到正确答案 83。但加上 "Explain step by step" 后,5 次全部正确。CoT prompting 让模型 "写出推理过程",大幅提高了单次采样的正确率,self-consistency 进一步通过投票消除偶发错误。两者结合效果远超单独使用。


5. 完整实验结果

书中在 NVIDIA DGX Spark 上对 MATH-500 完整 500 题做了系统实验:

方法 模型 准确率 耗时
Baseline (greedy) Base 15.2% 10 min
Baseline (greedy) Reasoning 48.2% 182 min
CoT prompting Base 40.6% 85 min
Top-p sampling alone Base 17.8% 31 min
Top-p + Self-consistency (n=3) Base 29.6% 98 min
Top-p + Self-consistency (n=10) Base 31.6% 300 min
CoT + Top-p Base 33.4% 129 min
CoT + Top-p + Self-consistency (n=5) Base 48.0% 453 min
CoT + Top-p + Self-consistency (n=10) Base 52.0% 863 min
CoT + Top-p + Self-consistency (n=3) Reasoning 55.2% 544 min
Results comparison chart

不同方法的准确率与计算成本对比 (图源: Reasoning from Scratch)

几个关键发现:

  • CoT 是最高效的单一改进:仅加一句 "Explain step by step",准确率从 15.2% → 40.6%(2.7x),时间只增加 8.5x(因为输出更长)。
  • Self-consistency 收益递减:n=3 到 n=10,准确率从 42.2% 到 52.0%,但时间从 212 min 增到 863 min。
  • Base + CoT + SC(n=10) 超过了 Reasoning model baseline:52.0% vs 48.2%,说明 inference-time scaling 可以弥补模型训练的不足。
  • 但代价是巨大的计算开销:863 min vs 10 min,86 倍的推理成本。这就是为什么我们还需要第6-7章的 RL 训练。

6. 关键收获

  • Inference-time scaling 是模型训练的有效补充:不需要改动模型权重,就能在推理时大幅提升性能。实际应用中(如 Claude 4),多次并行采样 + 内部评分模型是标准做法。
  • Temperature 和 Top-p 是采样的两个基本旋钮:Temperature 控制分布锐度,Top-p 截断低概率尾部。两者联合使用才能在多样性和连贯性之间取得平衡。
  • Self-consistency 原理简单但有效:「如果模型从不同的推理路径得到相同答案,那个答案更可能正确」— 这个直觉在实验中得到了验证。
  • CoT prompting 是最具性价比的技术:一句话的改动,效果堪比 10 次采样投票。它的本质是让模型 "写出中间步骤",模型在预测下一个 token 时能 attend 到自己之前的推理过程。
  • 计算换精度有上限:Self-consistency 的收益随 n 增大而递减。要进一步提升,需要 train-time scaling(第6-7章的 RL 训练)或更聪明的 inference-time 方法(第5章的 self-refinement)。

ℹ️ 下一章预告:Self-consistency 是 "生成多个答案,选最好的"。第5章将进一步实现 self-refinement — 让模型生成一个答案后,用 log-probability 评估它的质量,然后迭代改进,直到得到更好的结果。

本文是 Build a Reasoning Model (From Scratch) (Sebastian Raschka) 的学习笔记。所有配图版权归原作者所有。代码基于原书示例,有简化和中文注释。

← Ch3: 评估推理模型Ch5: 自我优化 →

Related Posts