推理模型从零构建 第5章:自我优化 (Self-Refinement)
1. Self-Refinement 的核心思想
Self-refinement 的灵感来自人类的思考过程:写完一个答案后,回顾检查 → 找出问题 → 修正。对 LLM 来说,这个过程分为三步:
但这里有一个关键问题:修正后的答案不一定比原来的好。我们需要一个评分函数来判断是否接受修正。书中介绍了两种评分方法:
2. 启发式评分
第一种评分方法不需要调用模型,纯粹基于规则判断回答的 “形式质量”:
import math
def heuristic_score(answer, prompt=None, brevity_bonus=500.0,
boxed_bonus=2.0, extract_bonus=1.0):
score = 0.0
# 有 \boxed{} 答案框的回答加分(说明格式规范)
cand = extract_final_candidate(answer, fallback="none")
if cand:
score += boxed_bonus # +2.0
else:
# 至少有一个数字也加分
cand = extract_final_candidate(answer, fallback="number_only")
if cand:
score += extract_bonus # +1.0
# 简洁性奖励:回答越短得分越高(指数衰减)
score += 1.5 * math.exp(-len(answer) / brevity_bonus)
return score
# 两个同样正确的回答(x=83),但长度不同:
# Response 1 (1419字符): score = 2.088
# Response 2 (533字符): score = 2.517 ← 更简洁,得分更高
ℹ️ 为什么奖励简洁性?在数学推理中,过长的回答往往意味着模型在 “绕弯路” 或产生了不相关的内容。简洁而正确的推理链通常质量更高。brevity_bonus=500 意味着 500 字符以内几乎满分,超过 2000 字符后奖励趋近于 0。
3. Log-Probability 置信度评分
更强大的评分方法是利用模型自身的置信度:如果模型对自己生成的每个 token 都很确定(概率高),说明回答质量可能更好。
3.1 为什么用 log-probability 而不是 probability?
序列的联合概率是每个 token 概率的乘积。但连乘很多小数(0.1 × 0.3 × 0.01 × …)很快就会下溢到 0。取对数后,乘法变加法,数值更稳定:
用代码验证:
# "Berlin" 的联合概率:5.9e-08(极小,难以比较)
# "Bridge" 的联合概率:1.0e-13
# 取 log 后:
# "Berlin" 的 joint log-prob: -16.6(可以直接比较大小)
# "Bridge" 的 joint log-prob: -29.9
# "Hamburg" 的 joint log-prob: -18.4
# → Berlin > Hamburg > Bridge,符合直觉
3.2 avg_logprob_answer 评分函数
关键设计:只对回答部分(不含 prompt)的 token 取平均 log-probability。这样避免了 prompt 长度对评分的干扰,也避免了长回答因 token 多而被惩罚:
@torch.inference_mode()
def avg_logprob_answer(model, tokenizer, prompt, answer, device="cpu"):
"""计算模型对回答部分的平均 log-probability"""
# 分别编码 prompt 和 answer
prompt_ids = tokenizer.encode(prompt)
answer_ids = tokenizer.encode(answer)
full_ids = torch.tensor(prompt_ids + answer_ids, device=device)
# 前向传播,得到所有位置的 log-probabilities
logits = model(full_ids.unsqueeze(0)).squeeze(0)
logprobs = torch.log_softmax(logits, dim=-1)
# 只取回答部分的 token 得分
start = len(prompt_ids) - 1
end = full_ids.shape[0] - 1
t_idx = torch.arange(start, end, device=device)
next_tokens = full_ids[start + 1 : end + 1]
next_token_logps = logprobs[t_idx, next_tokens]
# 取平均(归一化掉长度影响)
return torch.mean(next_token_logps)
# 测试:模型对正确答案的置信度更高
avg_logprob_answer(model, tokenizer,
prompt="What is the capital of Germany?",
answer=" The capital of Germany is Berlin.") # → -0.204
avg_logprob_answer(model, tokenizer,
prompt="What is the capital of Germany?",
answer=" The capital of Germany is Bridge.") # → -3.891
直觉:log-prob 越接近 0 说明模型越 “自信”。Berlin (-0.204) 远高于 Bridge (-3.891),说明模型非常确定德国首都是柏林而不是 Bridge。
4. Self-Refinement 循环
有了评分函数,就可以构建完整的 self-refinement 循环。每一轮包含三步:批评 → 修正 → 评分决策。
批评和修正各自有专门的 prompt template:
def make_critique_prompt(raw_prompt, draft):
return (
"You are a meticulous reviewer. Identify logical errors, "
"missing steps, or arithmetic mistakes. If the answer seems "
"correct, say so briefly. Then propose a concise plan to fix issues.\n\n"
f"Question:\n{raw_prompt}\n\n"
f"Draft answer:\n{draft}\n\n"
"Write a short critique and bullet-point fix plan "
"(under ~120 words).\nCritique:"
)
def make_refine_prompt(raw_prompt, draft, critique):
return (
"Revise the answer using the critique. Keep it concise and "
"end with a final boxed result: \\boxed{ANSWER}\n\n"
f"Question:\n{raw_prompt}\n\n"
f"Previous answer:\n{draft}\n\n"
f"Critique:\n{critique}\n\n"
"Revised answer:"
)
完整循环的核心逻辑:
def self_refinement_loop(model, tokenizer, raw_prompt, device,
iterations=2, score_fn=None, ...):
# 生成初始回答
current_full = generate(model, tokenizer, prompt, ...)
current_score = score_fn(answer=current_full, prompt=prompt)
for it in range(iterations):
# 1. 模型自我批评
critique = generate(model, tokenizer,
make_critique_prompt(raw_prompt, current_full), ...)
# 2. 根据批评修正
revised = generate(model, tokenizer,
make_refine_prompt(raw_prompt, current_full, critique), ...)
revised_score = score_fn(answer=revised, prompt=prompt)
# 3. 只在修正版更好时接受
if revised_score >= current_score:
current_full = revised
current_score = revised_score
return current_full
# 实际运行结果(seed=1, 2轮迭代):
# 初始回答: \boxed{18} (错误) score = -0.855
# 第1轮修正: \boxed{83} (正确!) score = -0.226 → 接受
# 第2轮修正: \boxed{83} (正确) score = -1.320 → 拒绝(更差了)
# 最终答案: 83 ✓
⚠️ 观察关键细节:初始回答 \boxed{18} 完全错误,但经过一轮 critique → refine 后修正为 \boxed{83}。第二轮修正虽然答案仍是 83,但 score 从 -0.226 降到 -1.320(模型置信度下降了),所以被自动拒绝。这说明 avg_logprob 评分能有效防止 “越改越差”。
5. 关键收获
- Self-refinement 是 inference-time scaling 的进阶版:不只是多次采样选最好的(self-consistency),而是主动 “反思” 和 “改进”。这更接近人类解题时的思维过程。
- Log-probability 是衡量模型置信度的天然指标:不需要外部模型,LLM 前向传播的输出本身就包含了置信度信息。avg_logprob_answer 只算回答部分的平均值,既公平又实用。
- 评分函数是 “接受/拒绝” 的关门人:没有评分函数的 self-refinement 可能越改越差。heuristic_score 检查格式是否规范,avg_logprob 检查模型是否 “自信”,两者可以组合使用。
- 但 self-refinement 有根本局限:模型用自己来批评自己、改进自己 — 如果模型本身的能力不够,”批评” 可能不准确,”修正” 也可能引入新错误。这就是为什么第6-7章要用 RL 训练来从根本上提升模型的推理能力。
- Log-probability 概念将在 GRPO 中再次出现:第6章的 GRPO 训练中,
sequence_logprob函数和这里的avg_logprob_answer非常相似,都是计算序列的对数概率。理解本章的概念是理解 GRPO 的前提。
ℹ️ 下一章预告:前4-5章的所有技术都是 inference-time 方法 — 模型权重不变。第6章将进入 train-time scaling 领域,使用 GRPO (Group Relative Policy Optimization) 强化学习算法真正训练模型,让它学会自主推理。这是全书的核心章节。
本文是 Build a Reasoning Model (From Scratch) (Sebastian Raschka) 的学习笔记。所有配图版权归原作者所有。代码基于原书示例,有简化和中文注释。