推理模型从零构建 第5章:自我优化 (Self-Refinement)

📅 2026-03-02📖 ~8 min readReasoningSelf-RefinementLog-Probability
本文是 Build a Reasoning Model (From Scratch)(Sebastian Raschka 著)第5章的学习笔记。上一章通过多次采样 + 投票来挑选最佳答案,但模型本身并没有 “反思” 的过程。本章实现了一个更智能的策略:让模型生成答案 → 自我批评 → 修正 → 用置信度评分决定是否接受修正。核心是两个评分函数:heuristic_score(基于规则)和 avg_logprob_answer(基于模型置信度)。
← Ch4: Inference-Time ScalingCh6: GRPO强化学习 →

1. Self-Refinement 的核心思想

Self-refinement 的灵感来自人类的思考过程:写完一个答案后,回顾检查 → 找出问题 → 修正。对 LLM 来说,这个过程分为三步:

Self-refinement concept

Self-refinement 三步法:生成初始回答 → 模型自我批评 → 根据批评修正 (图源: Reasoning from Scratch)

但这里有一个关键问题:修正后的答案不一定比原来的好。我们需要一个评分函数来判断是否接受修正。书中介绍了两种评分方法:

Scoring methods overview

两种评分方法:启发式规则评分 和 基于 log-probability 的置信度评分 (图源: Reasoning from Scratch)

2. 启发式评分

第一种评分方法不需要调用模型,纯粹基于规则判断回答的 “形式质量”:

Python — Heuristic Score
import math

def heuristic_score(answer, prompt=None, brevity_bonus=500.0,
                     boxed_bonus=2.0, extract_bonus=1.0):
    score = 0.0

    # 有 \boxed{} 答案框的回答加分(说明格式规范)
    cand = extract_final_candidate(answer, fallback="none")
    if cand:
        score += boxed_bonus  # +2.0
    else:
        # 至少有一个数字也加分
        cand = extract_final_candidate(answer, fallback="number_only")
        if cand:
            score += extract_bonus  # +1.0

    # 简洁性奖励:回答越短得分越高(指数衰减)
    score += 1.5 * math.exp(-len(answer) / brevity_bonus)
    return score

# 两个同样正确的回答(x=83),但长度不同:
# Response 1 (1419字符): score = 2.088
# Response 2 (533字符):  score = 2.517 ← 更简洁,得分更高

ℹ️ 为什么奖励简洁性?在数学推理中,过长的回答往往意味着模型在 “绕弯路” 或产生了不相关的内容。简洁而正确的推理链通常质量更高。brevity_bonus=500 意味着 500 字符以内几乎满分,超过 2000 字符后奖励趋近于 0。


3. Log-Probability 置信度评分

更强大的评分方法是利用模型自身的置信度:如果模型对自己生成的每个 token 都很确定(概率高),说明回答质量可能更好。

3.1 为什么用 log-probability 而不是 probability?

序列的联合概率是每个 token 概率的乘积。但连乘很多小数(0.1 × 0.3 × 0.01 × …)很快就会下溢到 0。取对数后,乘法变加法,数值更稳定:

Joint probability vs log-probability

P(序列) = 各token概率之积;log P(序列) = 各token log概率之和 (图源: Reasoning from Scratch)

用代码验证:

Python — 概率 vs log 概率
# "Berlin" 的联合概率:5.9e-08(极小,难以比较)
# "Bridge" 的联合概率:1.0e-13

# 取 log 后:
# "Berlin" 的 joint log-prob: -16.6(可以直接比较大小)
# "Bridge" 的 joint log-prob: -29.9
# "Hamburg" 的 joint log-prob: -18.4
# → Berlin > Hamburg > Bridge,符合直觉

3.2 avg_logprob_answer 评分函数

关键设计:只对回答部分(不含 prompt)的 token 取平均 log-probability。这样避免了 prompt 长度对评分的干扰,也避免了长回答因 token 多而被惩罚:

avg_logprob scoring

只对回答部分的 token 计算平均 log-prob(灰色=prompt,蓝色=回答)(图源: Reasoning from Scratch)
Python — avg_logprob_answer
@torch.inference_mode()
def avg_logprob_answer(model, tokenizer, prompt, answer, device="cpu"):
    """计算模型对回答部分的平均 log-probability"""
    # 分别编码 prompt 和 answer
    prompt_ids = tokenizer.encode(prompt)
    answer_ids = tokenizer.encode(answer)
    full_ids = torch.tensor(prompt_ids + answer_ids, device=device)

    # 前向传播,得到所有位置的 log-probabilities
    logits = model(full_ids.unsqueeze(0)).squeeze(0)
    logprobs = torch.log_softmax(logits, dim=-1)

    # 只取回答部分的 token 得分
    start = len(prompt_ids) - 1
    end = full_ids.shape[0] - 1
    t_idx = torch.arange(start, end, device=device)
    next_tokens = full_ids[start + 1 : end + 1]
    next_token_logps = logprobs[t_idx, next_tokens]

    # 取平均(归一化掉长度影响)
    return torch.mean(next_token_logps)


# 测试:模型对正确答案的置信度更高
avg_logprob_answer(model, tokenizer,
    prompt="What is the capital of Germany?",
    answer=" The capital of Germany is Berlin.")  # → -0.204

avg_logprob_answer(model, tokenizer,
    prompt="What is the capital of Germany?",
    answer=" The capital of Germany is Bridge.")  # → -3.891

直觉:log-prob 越接近 0 说明模型越 “自信”。Berlin (-0.204) 远高于 Bridge (-3.891),说明模型非常确定德国首都是柏林而不是 Bridge。


4. Self-Refinement 循环

有了评分函数,就可以构建完整的 self-refinement 循环。每一轮包含三步:批评 → 修正 → 评分决策

Self-refinement loop

完整 self-refinement 循环:初始回答 → [批评 → 修正 → 评分]×N (图源: Reasoning from Scratch)

批评和修正各自有专门的 prompt template:

Python — 批评和修正 Prompt
def make_critique_prompt(raw_prompt, draft):
    return (
        "You are a meticulous reviewer. Identify logical errors, "
        "missing steps, or arithmetic mistakes. If the answer seems "
        "correct, say so briefly. Then propose a concise plan to fix issues.\n\n"
        f"Question:\n{raw_prompt}\n\n"
        f"Draft answer:\n{draft}\n\n"
        "Write a short critique and bullet-point fix plan "
        "(under ~120 words).\nCritique:"
    )

def make_refine_prompt(raw_prompt, draft, critique):
    return (
        "Revise the answer using the critique. Keep it concise and "
        "end with a final boxed result: \\boxed{ANSWER}\n\n"
        f"Question:\n{raw_prompt}\n\n"
        f"Previous answer:\n{draft}\n\n"
        f"Critique:\n{critique}\n\n"
        "Revised answer:"
    )

完整循环的核心逻辑:

Python — Self-Refinement Loop(简化)
def self_refinement_loop(model, tokenizer, raw_prompt, device,
                          iterations=2, score_fn=None, ...):
    # 生成初始回答
    current_full = generate(model, tokenizer, prompt, ...)
    current_score = score_fn(answer=current_full, prompt=prompt)

    for it in range(iterations):
        # 1. 模型自我批评
        critique = generate(model, tokenizer,
            make_critique_prompt(raw_prompt, current_full), ...)

        # 2. 根据批评修正
        revised = generate(model, tokenizer,
            make_refine_prompt(raw_prompt, current_full, critique), ...)
        revised_score = score_fn(answer=revised, prompt=prompt)

        # 3. 只在修正版更好时接受
        if revised_score >= current_score:
            current_full = revised
            current_score = revised_score

    return current_full


# 实际运行结果(seed=1, 2轮迭代):
# 初始回答: \boxed{18}  (错误)   score = -0.855
# 第1轮修正: \boxed{83}  (正确!)  score = -0.226 → 接受
# 第2轮修正: \boxed{83}  (正确)   score = -1.320 → 拒绝(更差了)
# 最终答案: 83 ✓

⚠️ 观察关键细节:初始回答 \boxed{18} 完全错误,但经过一轮 critique → refine 后修正为 \boxed{83}。第二轮修正虽然答案仍是 83,但 score 从 -0.226 降到 -1.320(模型置信度下降了),所以被自动拒绝。这说明 avg_logprob 评分能有效防止 “越改越差”。


5. 关键收获

  • Self-refinement 是 inference-time scaling 的进阶版:不只是多次采样选最好的(self-consistency),而是主动 “反思” 和 “改进”。这更接近人类解题时的思维过程。
  • Log-probability 是衡量模型置信度的天然指标:不需要外部模型,LLM 前向传播的输出本身就包含了置信度信息。avg_logprob_answer 只算回答部分的平均值,既公平又实用。
  • 评分函数是 “接受/拒绝” 的关门人:没有评分函数的 self-refinement 可能越改越差。heuristic_score 检查格式是否规范,avg_logprob 检查模型是否 “自信”,两者可以组合使用。
  • 但 self-refinement 有根本局限:模型用自己来批评自己、改进自己 — 如果模型本身的能力不够,”批评” 可能不准确,”修正” 也可能引入新错误。这就是为什么第6-7章要用 RL 训练来从根本上提升模型的推理能力
  • Log-probability 概念将在 GRPO 中再次出现:第6章的 GRPO 训练中,sequence_logprob 函数和这里的 avg_logprob_answer 非常相似,都是计算序列的对数概率。理解本章的概念是理解 GRPO 的前提。

ℹ️ 下一章预告:前4-5章的所有技术都是 inference-time 方法 — 模型权重不变。第6章将进入 train-time scaling 领域,使用 GRPO (Group Relative Policy Optimization) 强化学习算法真正训练模型,让它学会自主推理。这是全书的核心章节。

本文是 Build a Reasoning Model (From Scratch) (Sebastian Raschka) 的学习笔记。所有配图版权归原作者所有。代码基于原书示例,有简化和中文注释。

← Ch4: Inference-Time ScalingCh6: GRPO强化学习 →

Related Posts