Post

Mysterious GPT O1

【o1猜想】LLM inference scaling:MCTS

1 .简介

[OpenAI] o1用了 [Chain-of-thought]做inference,去训练[self-play] RL

  • o1提到要more [reinforcement learning] (train-time compute) 和with more time spent thinking (test-time compute)。
  • 在inference层面,test-time compute的scaling是未来发展的趋势,可能会比单纯地扩展模型参数更经济有效。

论文Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters [1] 提到,inference的scaling有两种形式:

  • refining the proposal distribution:在输入层面修改分布,self-critic / self-refine/ revisions,让模型不断修改自己的答案,串行
  • searching against a verifier :在输出层面修改分布,inference阶段利用[PRM],以Best-of-N/BFS/DFS/MCTS 等搜索方式来提升 reasoning 效果,并行

这里主要介绍searching against a verifier的方案,流程如下:

  • SFT训练policy model
  • 训练PRM,reward model / value model
  • 迭代进行以下流程:
  • 利用policy和reward/value进行MCTS等搜索方法推理,收集数据
  • 用数据去SFT/DPO/PPO 更新policy和reward model/value function

2.TOT

论文Tree of Thoughts: Deliberate Problem Solving with Large Language Models [2] 给出了ToT的概念:

  • 将CoT建模为树状结构,之后可以用BFS/DFS/MCTS等search方法增强推理能力。

树的节点划分可分为sentence-level和token-level(论文AlphaZero-Like Tree-Search can Guide Large Language Model Decoding and Training [3] 的划分方式)

sentence-level:将每个chain-of-thought reasoning视为一个节点,这也是ToT/RAP论文中所采用的方式。sentence-level动作节点提供了一个相对较浅的树(深度低),简化了树搜索过程,但是由于生成sentence的可能较多,需要限制树的宽度,也就在搜索空间上有所限制。

token-level:就是以每个token作为树的一个节点。对于token级动作节点,虽然它可以消除搜索空间上的限制,但增加了树的深度,使搜索变得更加困难。

3.训练PRM

论文Let’s Verify Step by Step [4] 中提出PRM的概念

  • Outcome-supervised Reward Models:奖励模型只对LLM输出的结果打分
  • Process-supervised Reward Models:奖励模型对LLM输出结果的每一步打分

PRM可以给出response中的哪些部分是错的,给出更精准的反馈。

PRM通过对每一行打分,有效定位出错误位置,如果过程中有错误,假设最终Answer是对的,整个LLM输出也会是低分,有效增强模型对于中间过程的学习能力。

训练PRM有几种方式:

  • 用人工标注的数据训练

  • 论文Let’s Verify Step by Step中对每一步都分配一个正、负或中性的标签,正标签表明该步骤是正确的和合理的。负标签表明该步骤要么不正确,要么不合理。中性标签表示歧义。可以推迟关于如何处理歧义的决定,可以将中性标签视为正面或负面。再有了人工标记的数据之后,再监督学习的方式训练PRM。

  • 也可以人工标注出二分类的数据训练:

  • 根据后续路径做hard estimation或soft estimation(论文Math-shepherd: A label-free step-by-step verifier for llms in mathematical reasoning)

  • 用LLM判断每一步是否正确(论文Reasoning with Language Model is Planning with World Model [5] 提到)
  • 学value function(the expected future reward)

  • AlphaZero等论文中会利用value function来进行MCTS。
  • RL中的value function表示在当前状态对未来reward的期望。通常用神经网络去计算value。
  • value function可以在search的过程中直接用reward model更新,也可以先离线学一个value model(论文AlphaZero-Like Tree-Search can Guide Large Language Model Decoding and Training)

有了PRM之后,就可以用search方法来增强推理,search的方法包括best-of-n/BFS/DFS/MCTS等方式。后续主要介绍MCTS。

4.MCTS

蒙特卡洛树搜索(Monte Carlo Tree Search,MCTS)是一种搜索算法,它利用蒙特卡洛模拟来进行搜索MCTS 维护了一个搜索树,这个搜索树记录了之前的搜索过的(状态动作序列)轨迹和相关统计信息(主要是平均回报与访问次数)。

MCTS的特点是可以兼顾探索(explore)和利用(exploit)

  • MCTS不需要对全部的游戏树进行搜索,而是通过统计信息来指导搜索过程,使得它特别适合于搜索空间巨大的问题。
  • 可以在大状态空间和大动作空间的问题中找到高质量的解,而且其性能可以通过增加模拟的次数来提高
  • 在大状态空间和大动作空间中执行时,MCTS 往往需要大量的计算资源。

MCTS可以分为四步:

  1. 选择(Selection):从根节点往下走,根据规则(例如UCB/TUCB)选一个节点,直到一个未扩展的节点
  2. 扩展(Expansion):对当前节点进行展开,获取它的子节点,并随机选择一个节点扩展
  3. 模拟(Simluation):从这个节点Rollout到终止状态,得到reward
  4. 回溯(Backpropagation):把Rollout的结果加到它的所有父节点上

  • 选择:通常根据UCT/PUCT/…等公式选择值最大的节点。综合考虑利用(第一项)和探索(第二项)。V/Q表示对节点的评估值,N表示节点的访问次数。

  • UCT公式:

  • PUTC公式:

  • 选择操作会不断进行,直到叶子节点。

  • 扩展:对当前节点进行展开,获取它的子节点。通常会获取所有的子节点,然后再随机选择一个节点进行后续操作。对于LLM,没法获取到所有的子节点,可以人工规定树的宽度。
  • 模拟:

  • 如果是标准的MCTS算法,会从该节点rollout,不断生成,一直到终止状态。
  • 如果是类似于AlphaZero论文等优化过的MCTS算法(论文AlphaZero-Like Tree-Search can Guide Large Language Model Decoding and Training中提到的MCTS-alpha/MCTS-rollout),不会直到终止状态,而是扩展了之后用value function进行评估,然后就可以将value值回传。

  • 回溯:将rollout的结果/value评估出的结果加到所有父节点上。同时访问次数+1。

MCTS过程会迭代若干次,在达到预定的MCTS迭代次数,会终止算法并从构建的树中选择最终的推理轨迹。选择的方式可能多种多样:

  • 从根节点开始,迭代地选择具有最高value值的动作,直到到达终端。
  • 直接从产生最高奖励的迭代中选择路径,
  • 或者选择访问次数最多的叶节点(以及相应的根到叶路径)

5.迭代训练

  • RAP / ToT论文只增强inference的过程,没有额外再用推理出的数据训练。

  • TS-LLM(AlphaZero-Like Tree-Search can Guide Large Language Model Decoding and Training)会利用迭代更新policy和value/reward:

  • 筛选MCTS采样出的positive数据,SFT更新policy
  • 用所有MCTS数据去更新value/reward

  • ReST-MCTS* [6] 也会迭代更新policy和value:

  • 正样本SFT更新policy
  • 正负样本更新PRM

  • 论文Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning [7] 收集每一个step的正负样本,训练DPO

  • 迭代进行MCTS搜集数据和用数据更新策略的过程:

  • 训练DPO的时候会根据访问次数N加权:

6.总结

目前o1所用到的技术尚不可知,但大体上离不开inference scaling/ scaling test-time,self-play RL,CoT/ToT,MCTS,self-critic/self-refine等方向。其中MCTS类的搜索算法在AlphaGo和AlphaZero时代表现出了强大的能力,又可以很好地与RL算法相结合,应用到RLHF的过程之中,是很有潜力的研究方向。

7.参考文献

[1] Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters

[2] Tree of Thoughts: Deliberate Problem Solving with Large Language Models

[3] AlphaZero-Like Tree-Search can Guide Large Language Model Decoding and Training

[4] Let’s Verify Step by Step

[5] Reasoning with Language Model is Planning with World Model

[6] ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search

[7] Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning

This post is licensed under CC BY 4.0 by the author.