田渊栋团队论文火了!连续思维链优于CoT,打开LLM推理新范式
一个非常简单的更改,就能提高 LLM 推理能力。
在认知科学领域,关于语言是用于思考还是用于交流的辩论一直持续。
随着 LLM 和 CoT 的兴起,语言已经成为机器推理的默认媒介 —— 但它真的是最佳方法吗?
一般而言,LLM 被限制在语言空间(language space)内进行推理,并通过思维链(CoT)来表达推理过程,从而解决复杂的推理问题。
然而,语言空间可能并不总是最适合推理的。例如,很多单词 token 主要用于文本连贯性,而不是推理本身,而一些关键 token 则需要复杂的规划,这种差异给 LLM 带来巨大的挑战。
为了探索 LLM 在不受限制潜在空间中的推理潜力,而非使用自然语言,来自 Meta、加州大学圣地亚哥分校的研究者提出了一种新的范式 ——Coconut(连续思维链,Chain of Continuous Thought),来探索 LLM 在潜在空间中的推理。
论文标题:Training Large Language Models to Reason in a Continuous Latent Space
论文地址:https://arxiv.org/pdf/2412.06769
Coconut 涉及对传统 CoT 过程的简单修改:Coconut 不再通过语言模型头(language model head)和嵌入层将隐藏状态与语言 token 进行映射,而是直接将最后的隐藏状态(即连续思维)作为下一个 token 的输入嵌入(如图 1 所示)。
这种修改将推理从语言空间中解放出来,并且由于连续思维是完全可微的,因此可以通过梯度下降对系统进行端到端优化。为了增强潜在推理的训练,本文采用了多阶段训练策略,该策略有效地利用语言推理链来指导训练过程。
这种范式带来了高效的推理模式,与基于语言的推理不同,Coconut 中的连续思维可以同时编码多个潜在下一步,从而实现类似于 BFS(breadth-first search)的推理过程。尽管模型在初始阶段可能做出不正确的决策,但它可以在连续思维中保持许多可能的选项,并通过推理逐步排除错误路径,这一过程由一些隐含的价值函数引导。这种高级的推理机制超越了传统的 CoT,即使模型并没有显式地接受训练或指示以这种方式操作。
实验表明,Coconut 成功增强了 LLM 的推理能力。对于数学推理(GSM8k),使用连续思维被证明有利于提高推理准确率,这与语言推理链的效果相似。通过链接更多连续思维,可以扩展和解决日益具有挑战性的问题。
在逻辑推理方面,包括 ProntoQA 和本文新提出的 ProsQA,这需要更强的规划能力,Coconut 及其一些变体甚至超越了基于语言的 CoT 方法,同时在推理过程中生成的 token 明显更少。
这项研究在 X 上的讨论量非常高,其中单人转发的浏览量就高达 20 多万。
连续思维链:Coconut
方法概述。在 Coconut 方法中,LLM 在语言模式和潜在模式之间切换(图 1):
在语言模式下,该模型作为标准语言模型运行,自回归生成下一个 token。
在潜在模式下,它直接利用最后一个隐藏状态作为下一个输入嵌入。这个最后的隐藏状态代表当前的推理状态,称为连续思维。
特殊 token < bot >、< eot > 分别用于标记潜在思维模式的开始和结束。
训练。本文专注于问题 - 解决设置,其中模型接收问题作为输入,并通过推理过程生成答案。作者利用语言 CoT 数据来监督连续思维。如图 2 所示,在初始阶段,模型在常规 CoT 实例上进行训练。在后续阶段,即第 k 阶段,CoT 中的前 k 个推理步骤被替换为 k × c 个连续思维,其中 c 是一个超参数,用于控制替换单个语言推理步骤的潜在思维的数量。
推理过程。Coconut 的推理过程类似于标准的语言模型解码过程,不同之处在于,在潜在模式下,本文直接将最后一个隐藏状态作为下一个输入嵌入。这样做面临的挑战是确定何时在潜在模式和语言模式之间切换。当专注于问题 - 解决设置时,本文会在问题 token 后立即插入一个 < bot >token。对于 < eot >,作者考虑两种潜在策略:a) 在潜在思维上训练二元分类器,使模型能够自主决定何时终止潜在推理,或 b) 始终将潜在思维填充到恒定长度。本文发现这两种方法效果都相当好。除非另有说明,本文在实验中使用第二种选项以简化操作。
实验
研究团队通过三个数据集验证了大语言模型在连续潜空间中进行推理的可行性。实验主要评估模型生成答案的准确性和推理效率。
实验涉及两类主要任务:数学推理和逻辑推理。数学推理使用 GSM8k 数据集。逻辑推理则采用了两个数据集:5-hop ProntoQA 与该团队自行开发的 ProsQA。
ProntoQA 给出一个层级分类的知识结构,要求模型判断不同类别之间的从属关系是否正确。而 ProsQA 中是更具挑战性的推理任务,包含许多随机生成的有向无环图,要求模型进行大量规划和搜索。
实验设置
在实验设置方面,研究采用预训练的 GPT-2 模型,学习率为 1×10^−4,批量大小为 128。
对于数学推理任务,每个推理步骤使用 2 个潜在思维向量表示,整个训练过程分为 4 个渐进式阶段。
在逻辑推理任务中,每步使用 1 个潜在思维向量,训练分为 7 个渐进式阶段,逐步增加难度。所有实验均在标准训练流程后继续训练至第 50 轮,并通过在验证集上评估准确率来选择性能最佳的模型检查点用于最终测试。
基线方法和各种版本的 Coconut
为了全面评估方法效果,研究团队设置了以下基线方法进行对比:
1. 传统的 CoT:使用完整的思维链进行训练,让模型生成每一步的推理过程
2. No-CoT:模型直接生成最终答案,不要求中间推理步骤
3. iCoT:采用渐进式策略,逐步移除推理链中的步骤
4. Pause token:在问题和答案之间插入特殊的暂停 token
同时,他们还评估了 Coconut 的三个变体版本:
1. 无课程学习版本:跳过渐进训练,直接采用最终阶段的训练方式
2. 无思维版本:移除连续思维表示,仅保留分阶段训练机制
3. 思维替换版本:用特殊 token 替代连续思维的表示方式
结果与讨论