mirror of https://github.com/hpcaitech/ColossalAI
[MCTS] Add self-refined MCTS (#6098)
* add reasoner
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update code
* delete llama
* update prompts
* update readme
* update readme

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 4294ae83bb
commit 89a9a600bc
@ -27,11 +27,11 @@

- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
  - [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq)
  - [How to save/load checkpoint](#faq)
  - [How to train with limited resources](#faq)
- [The Plan](#the-plan)
  - [Real-time progress](#real-time-progress)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview)
- [Authors](#authors)
@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd

## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)

We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.

## Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen

Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).

## O1 Journey

### Inference with Self-refined MCTS

We provide an implementation of the MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
To run inference with MCTS, simply use the following script.

```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG

problem = "How Many R in 'Strawberry'"

search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
answer = search_tree.simulate()
print(answer)
```
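Note that `Qwen32B_prompt_CFG` assumes an OpenAI-compatible endpoint serving `Qwen2.5-32B-Instruct` at `http://0.0.0.0:8008/v1`. To point the search at a different deployment, one option is to copy the config with an updated `base_url` — a minimal sketch, assuming Pydantic v2; the URL below is a placeholder:

```python
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG

# Placeholder URL; replace with your own OpenAI-compatible server.
my_cfg = Qwen32B_prompt_CFG.model_copy(update={"base_url": "http://localhost:8000/v1"})
```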

## Coati7B examples

### Generation
coati/reasoner/guided_search/llm.py
@ -0,0 +1,26 @@

import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

API_KEY = "Dummy API Key"


def get_client(base_url: str | None = None) -> openai.Client:
    return openai.Client(api_key=API_KEY, base_url=base_url)


def chat_completion(
    messages: list[ChatCompletionMessageParam],
    model: str,
    base_url: str | None = None,
    temperature: float = 0.8,
    **kwargs,
) -> ChatCompletion:
    # Thin wrapper around the OpenAI chat completions API.
    client = get_client(base_url)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        **kwargs,
    )
    return response
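A quick way to smoke-test this wrapper against a local OpenAI-compatible server (the endpoint and model name below match the shipped Qwen config but are deployment-dependent placeholders):

```python
from coati.reasoner.guided_search.llm import chat_completion

# Placeholder endpoint and model; adjust to your deployment.
response = chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    model="Qwen2.5-32B-Instruct",
    base_url="http://0.0.0.0:8008/v1",
)
print(response.choices[0].message.content)
```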
coati/reasoner/guided_search/mcts.py
@ -0,0 +1,250 @@

"""
Implementation of MCTS + Self-refine algorithm.

Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""

from __future__ import annotations

import math
from collections import deque

import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel


class MCTSNode(BaseModel):
    """
    Node for MCTS.
    """

    answer: str
    parent: MCTSNode | None = None
    children: list[MCTSNode] = []
    num_visits: int = 0
    Q: float = 0
    rewards: list[int] = []

    def expand_node(self, node: MCTSNode) -> None:
        self.children.append(node)

    def add_reward(self, reward: int) -> None:
        # Node value is the average of the worst and the mean sampled reward,
        # so a single bad sample keeps dragging the value down.
        self.rewards.append(reward)
        self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2


class MCTS(BaseModel):
    """
    Simulation of MCTS process.
    """

    problem: str
    max_simulations: int
    cfg: PromptCFG
    C: float = 1.4
    max_children: int = 2
    epsilon: float = 1e-5
    root: MCTSNode | None = None

    def initialization(self):
        """
        Root initialization.
        """
        # Dummy answer as root.
        base_answer = self.sample_base_answer()
        self.root = MCTSNode(answer=base_answer)
        self.self_evaluate(self.root)

    def is_fully_expanded(self, node: MCTSNode):
        return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)

    def select_node(self) -> MCTSNode:
        """
        Select the next node to explore.
        """
        candidates: list[MCTSNode] = []
        to_explore = deque([self.root])

        while to_explore:
            current_node = to_explore.popleft()
            if not self.is_fully_expanded(current_node):
                candidates.append(current_node)
            to_explore.extend(current_node.children)

        if not candidates:
            return self.root

        return max(candidates, key=self.compute_uct)

    def self_evaluate(self, node: MCTSNode):
        """
        Sample the reward of the answer.
        """
        reward = self.sample_reward(node)
        node.add_reward(reward)

    def back_propagation(self, node: MCTSNode):
        """
        Back-propagate the value of the refined answer.
        """
        parent = node.parent
        while parent:
            best_child_Q = max(child.Q for child in parent.children)
            parent.Q = (parent.Q + best_child_Q) / 2
            parent.num_visits += 1
            parent = parent.parent

    def compute_uct(self, node: MCTSNode):
        """
        Compute the UCT score.
        """
        if node.parent is None:
            # The root has no parent; give it a score low enough that it is
            # only selected when no other candidate exists.
            return -100
        return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))

    def simulate(self):
        self.initialization()

        for _ in tqdm.tqdm(range(self.max_simulations)):
            node = self.select_node()
            child = self.self_refine(node)
            node.expand_node(child)
            self.self_evaluate(child)
            self.back_propagation(child)

        return self.get_best_answer()

    def get_best_answer(self):
        # Breadth-first scan of the tree for the highest-Q node.
        to_visit = deque([self.root])
        best_node = self.root

        while to_visit:
            current_node = to_visit.popleft()
            if current_node.Q > best_node.Q:
                best_node = current_node
            to_visit.extend(current_node.children)

        return best_node.answer

    def self_refine(self, node: MCTSNode):
        """
        Refine the node: critique the current answer, then rewrite it.
        """
        critique_response = chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": self.cfg.critic_system_prompt,
                },
                {
                    "role": "user",
                    "content": "\n\n".join(
                        [
                            f"<problem>\n{self.problem}\n</problem>",
                            f"<current_answer>\n{node.answer}\n</current_answer>",
                        ]
                    ),
                },
            ],
            model=self.cfg.model,
            base_url=self.cfg.base_url,
            max_tokens=self.cfg.max_tokens,
        )
        critique = critique_response.choices[0].message.content
        assert critique is not None
        refined_answer_response = chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": self.cfg.refine_system_prompt,
                },
                {
                    "role": "user",
                    "content": "\n\n".join(
                        [
                            f"<problem>\n{self.problem}\n</problem>",
                            f"<current_answer>\n{node.answer}\n</current_answer>",
                            f"<critique>\n{critique}\n</critique>",
                        ]
                    ),
                },
            ],
            model=self.cfg.model,
            base_url=self.cfg.base_url,
            max_tokens=self.cfg.max_tokens,
        )
        refined_answer = refined_answer_response.choices[0].message.content
        assert refined_answer is not None

        return MCTSNode(answer=refined_answer, parent=node)

    def sample_base_answer(self):
        response = chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
                },
                {
                    "role": "user",
                    "content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
                },
            ],
            model=self.cfg.model,
            base_url=self.cfg.base_url,
            max_tokens=self.cfg.max_tokens,
        )
        assert response.choices[0].message.content is not None
        return response.choices[0].message.content

    def sample_reward(self, node: MCTSNode):
        """
        Calculate the reward by asking the evaluator model for an integer score.
        """
        messages = [
            {
                "role": "system",
                "content": self.cfg.evaluate_system_prompt,
            },
            {
                "role": "user",
                "content": "\n\n".join(
                    [
                        f"<problem>\n{self.problem}\n</problem>",
                        f"<answer>\n{node.answer}\n</answer>",
                    ]
                ),
            },
        ]
        for attempt in range(3):
            try:
                response = chat_completion(
                    messages=messages,
                    model=self.cfg.model,
                    base_url=self.cfg.base_url,
                    max_tokens=self.cfg.max_tokens,
                )
                assert response.choices[0].message.content is not None
                return int(response.choices[0].message.content)
            except ValueError:
                # Feed the unparseable reply back to the model and retry.
                messages.extend(
                    [
                        {
                            "role": "assistant",
                            "content": response.choices[0].message.content,
                        },
                        {
                            "role": "user",
                            "content": "Failed to parse reward as an integer.",
                        },
                    ]
                )
                if attempt == 2:
                    raise
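For intuition: `MCTSNode.add_reward` sets a node's value `Q` to the average of its worst and mean sampled rewards, and `MCTS.compute_uct` adds an exploration bonus that is large for rarely visited nodes. A standalone check of the arithmetic (illustrative only, not part of the shipped code):

```python
import math

import numpy as np

# Node value as in MCTSNode.add_reward: average of worst and mean reward.
rewards = [20, 80, 60]
q = (np.min(rewards) + np.mean(rewards)) / 2  # (20 + 53.33) / 2 ~= 36.67


# UCT as in MCTS.compute_uct, with C = 1.4 and epsilon = 1e-5.
def uct(q: float, parent_visits: int, visits: int, c: float = 1.4, eps: float = 1e-5) -> float:
    return q + c * math.sqrt(math.log(parent_visits + 1) / (visits + eps))


print(uct(q, parent_visits=8, visits=6))  # well-visited node: small bonus (~37.5)
print(uct(q, parent_visits=8, visits=0))  # unvisited node: huge bonus, so it is explored first
```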
coati/reasoner/guided_search/prompt_store/base.py
@ -0,0 +1,10 @@

from pydantic import BaseModel


class PromptCFG(BaseModel):
    model: str
    base_url: str
    max_tokens: int = 4096
    critic_system_prompt: str
    refine_system_prompt: str
    evaluate_system_prompt: str
coati/reasoner/guided_search/prompt_store/qwen.py
@ -0,0 +1,20 @@

"""
Prompts for the Qwen series.
"""

from coati.reasoner.guided_search.prompt_store.base import PromptCFG

Qwen32B_prompt_CFG = PromptCFG(
    base_url="http://0.0.0.0:8008/v1",
    model="Qwen2.5-32B-Instruct",
    critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
    "Highlight specific areas that need refinement or correction.",
    refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
    evaluate_system_prompt=(
        "Analyze this answer strictly and critically, and provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
        "Do not give a full score above 95. Make sure the reward score is an integer. "
        "Return *ONLY* the score."
    ),
)
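Prompt configs for other backends can follow the same pattern. A minimal sketch (the model name, URL, and prompt strings below are illustrative placeholders, not shipped defaults):

```python
from coati.reasoner.guided_search.prompt_store.base import PromptCFG

# All values below are placeholders; substitute your own deployment and prompts.
MyModel_prompt_CFG = PromptCFG(
    base_url="http://localhost:8000/v1",
    model="my-chat-model",
    critic_system_prompt="Critique the answer and point out concrete mistakes.",
    refine_system_prompt="Rewrite the answer, fixing every issue raised in the critique.",
    evaluate_system_prompt="Score the answer between -100 and 100. Return only the integer score.",
)
```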