From 89a9a600bc4802c912b0ed48d48f70bbcdd8142b Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 24 Oct 2024 17:51:19 +0800
Subject: [PATCH] [MCTS] Add self-refined MCTS (#6098)

* add reasoner

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update code

* delete llama

* update prompts

* update readme

* update readme

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 applications/ColossalChat/README.md           |  21 +-
 .../coati/reasoner/guided_search/llm.py       |  26 ++
 .../coati/reasoner/guided_search/mcts.py      | 250 ++++++++++++++++++
 .../guided_search/prompt_store/base.py        |  10 +
 .../guided_search/prompt_store/qwen.py        |  20 ++
 5 files changed, 324 insertions(+), 3 deletions(-)
 create mode 100644 applications/ColossalChat/coati/reasoner/guided_search/llm.py
 create mode 100644 applications/ColossalChat/coati/reasoner/guided_search/mcts.py
 create mode 100644 applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
 create mode 100644 applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py

diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index 100cc5ece..ef904b864 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -27,11 +27,11 @@
 - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
 - [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
 - [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [O1 Journey](#o1-journey)
+  - [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
 - [FAQ](#faq)
   - [How to save/load checkpoint](#faq)
   - [How to train with limited resources](#faq)
-- [The Plan](#the-plan)
-  - [Real-time progress](#real-time-progress)
 - [Invitation to open-source contribution](#invitation-to-open-source-contribution)
 - [Quick Preview](#quick-preview)
 - [Authors](#authors)
@@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
 ## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
 We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.

-### Inference Quantization and Serving - After Training
+## Inference Quantization and Serving - After Training

 We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.

@@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen

 Online inference server scripts can help you deploy your own services. For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).

+## O1 Journey
+### Inference with Self-refined MCTS
+We provide an implementation of the MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
+To run inference with MCTS, simply use the following script.
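+Note that the example assumes an OpenAI-compatible endpoint serving the target model is reachable at the `base_url` configured in `coati/reasoner/guided_search/prompt_store/qwen.py`.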
+
+```python
+from coati.reasoner.guided_search.mcts import MCTS
+from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
+
+problem = "How many r's are in 'strawberry'?"
+
+search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
+answer = search_tree.simulate()
+print(answer)
+```
+
 ## Coati7B examples

 ### Generation

diff --git a/applications/ColossalChat/coati/reasoner/guided_search/llm.py b/applications/ColossalChat/coati/reasoner/guided_search/llm.py
new file mode 100644
index 000000000..5025a98ea
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/llm.py
@@ -0,0 +1,26 @@
+import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+
+# Placeholder key: the client is meant to target a self-hosted, OpenAI-compatible server.
+API_KEY = "Dummy API Key"
+
+
+def get_client(base_url: str | None = None) -> openai.Client:
+    return openai.Client(api_key=API_KEY, base_url=base_url)
+
+
+def chat_completion(
+    messages: list[ChatCompletionMessageParam],
+    model: str,
+    base_url: str | None = None,
+    temperature: float = 0.8,
+    **kwargs,
+) -> ChatCompletion:
+    client = get_client(base_url)
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        **kwargs,
+    )
+    return response
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
new file mode 100644
index 000000000..693e2b750
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -0,0 +1,250 @@
+"""
+Implementation of the MCTS + Self-refine algorithm.
+
+Reference:
+1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
+Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
+2. https://github.com/BrendanGraham14/mcts-llm/
+3. https://github.com/trotsky1997/MathBlackBox/
+4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
+"""
+
+from __future__ import annotations
+
+import math
+from collections import deque
+
+import numpy as np
+import tqdm
+from coati.reasoner.guided_search.llm import chat_completion
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+from pydantic import BaseModel
+
+
+class MCTSNode(BaseModel):
+    """
+    Node for MCTS.
+    """
+
+    answer: str
+    parent: MCTSNode | None = None
+    children: list[MCTSNode] = []
+    num_visits: int = 0
+    Q: float = 0.0
+    rewards: list[int] = []
+
+    def expand_node(self, node: MCTSNode) -> None:
+        self.children.append(node)
+
+    def add_reward(self, reward: int) -> None:
+        self.rewards.append(reward)
+        # Average the worst reward with the mean to penalize unstable answers.
+        self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
+
+
+class MCTS(BaseModel):
+    """
+    Simulation of the MCTS process.
+    """
+
+    problem: str
+    max_simulations: int
+    cfg: PromptCFG
+    C: float = 1.4
+    max_children: int = 2
+    epsilon: float = 1e-5
+    root: MCTSNode | None = None
+
+    def initialization(self):
+        """
+        Root initialization.
+        """
+        # Dummy answer as root.
+        base_answer = self.sample_base_answer()
+        self.root = MCTSNode(answer=base_answer)
+        self.self_evaluate(self.root)
+
+    def is_fully_expanded(self, node: MCTSNode):
+        return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
+
+    def select_node(self) -> MCTSNode:
+        """
+        Select the next node to explore.
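+
+        A breadth-first scan collects every node that is not yet fully
+        expanded; among those candidates the one with the highest UCT
+        value is returned (or the root itself when none qualifies).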
+        """
+        candidates: list[MCTSNode] = []
+        to_explore = deque([self.root])
+
+        while to_explore:
+            current_node = to_explore.popleft()
+            if not self.is_fully_expanded(current_node):
+                candidates.append(current_node)
+            to_explore.extend(current_node.children)
+
+        if not candidates:
+            return self.root
+
+        return max(candidates, key=self.compute_uct)
+
+    def self_evaluate(self, node: MCTSNode):
+        """
+        Sample a reward for the answer and attach it to the node.
+        """
+        reward = self.sample_reward(node)
+        node.add_reward(reward)
+
+    def back_propagation(self, node: MCTSNode):
+        """
+        Back-propagate the value of the refined answer to its ancestors.
+        """
+        parent = node.parent
+        while parent:
+            best_child_Q = max(child.Q for child in parent.children)
+            parent.Q = (parent.Q + best_child_Q) / 2
+            parent.num_visits += 1
+            parent = parent.parent
+
+    def compute_uct(self, node: MCTSNode):
+        """
+        Compute the UCT value of a node.
+        """
+        if node.parent is None:
+            # The root should not be re-selected, so give it a very low score.
+            return -100
+        return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
+
+    def simulate(self):
+        self.initialization()
+
+        for _ in tqdm.tqdm(range(self.max_simulations)):
+            node = self.select_node()
+            child = self.self_refine(node)
+            node.expand_node(child)
+            self.self_evaluate(child)
+            self.back_propagation(child)
+
+        return self.get_best_answer()
+
+    def get_best_answer(self):
+        to_visit = deque([self.root])
+        best_node = self.root
+
+        while to_visit:
+            current_node = to_visit.popleft()
+            if current_node.Q > best_node.Q:
+                best_node = current_node
+            to_visit.extend(current_node.children)
+
+        return best_node.answer
+
+    def self_refine(self, node: MCTSNode):
+        """
+        Refine a node's answer: critique it first, then rewrite it based on the critique.
+        """
+        critique_response = chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": self.cfg.critic_system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": "\n\n".join(
+                        [
+                            f"<problem>\n{self.problem}\n</problem>",
+                            f"<answer>\n{node.answer}\n</answer>",
+                        ]
+                    ),
+                },
+            ],
+            model=self.cfg.model,
+            base_url=self.cfg.base_url,
+            max_tokens=self.cfg.max_tokens,
+        )
+        critique = critique_response.choices[0].message.content
+        assert critique is not None
+        refined_answer_response = chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": self.cfg.refine_system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": "\n\n".join(
+                        [
+                            f"<problem>\n{self.problem}\n</problem>",
+                            f"<answer>\n{node.answer}\n</answer>",
+                            f"<critique>\n{critique}\n</critique>",
+                        ]
+                    ),
+                },
+            ],
+            model=self.cfg.model,
+            base_url=self.cfg.base_url,
+            max_tokens=self.cfg.max_tokens,
+        )
+        refined_answer = refined_answer_response.choices[0].message.content
+        assert refined_answer is not None
+
+        return MCTSNode(answer=refined_answer, parent=node)
+
+    def sample_base_answer(self):
+        response = chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
+                },
+                {
+                    "role": "user",
+                    "content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
+                },
+            ],
+            model=self.cfg.model,
+            base_url=self.cfg.base_url,
+            max_tokens=self.cfg.max_tokens,
+        )
+        assert response.choices[0].message.content is not None
+        return response.choices[0].message.content
+
+    def sample_reward(self, node: MCTSNode):
+        """
+        Calculate the reward of a node's answer.
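+
+        The evaluator model is asked for an integer score. Up to three
+        attempts are made; on a parse failure the raw reply is fed back
+        into the conversation so the model can correct its format.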
+        """
+        messages = [
+            {
+                "role": "system",
+                "content": self.cfg.evaluate_system_prompt,
+            },
+            {
+                "role": "user",
+                "content": "\n\n".join(
+                    [
+                        f"<problem>\n{self.problem}\n</problem>",
+                        f"<answer>\n{node.answer}\n</answer>",
+                    ]
+                ),
+            },
+        ]
+        for attempt in range(3):
+            try:
+                response = chat_completion(
+                    messages=messages,
+                    model=self.cfg.model,
+                    base_url=self.cfg.base_url,
+                    max_tokens=self.cfg.max_tokens,
+                )
+                assert response.choices[0].message.content is not None
+                return int(response.choices[0].message.content)
+            except ValueError:
+                # Feed the unparsable reply back so the model can correct itself.
+                messages.extend(
+                    [
+                        {
+                            "role": "assistant",
+                            "content": response.choices[0].message.content,
+                        },
+                        {
+                            "role": "user",
+                            "content": "Failed to parse the reward as an integer; reply with an integer only.",
+                        },
+                    ]
+                )
+                if attempt == 2:
+                    raise
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
new file mode 100644
index 000000000..b325b8fa2
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class PromptCFG(BaseModel):
+    model: str
+    base_url: str
+    max_tokens: int = 4096
+    critic_system_prompt: str
+    refine_system_prompt: str
+    evaluate_system_prompt: str
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
new file mode 100644
index 000000000..8bf0fa959
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
@@ -0,0 +1,20 @@
+"""
+Prompts for the Qwen series.
+"""
+
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+
+Qwen32B_prompt_CFG = PromptCFG(
+    base_url="http://0.0.0.0:8008/v1",
+    model="Qwen2.5-32B-Instruct",
+    critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
+    "Highlight specific areas that need refinement or correction.",
+    refine_system_prompt="""# Instruction
+        Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
+    """,
+    evaluate_system_prompt=(
+        "Analyze this answer strictly and critically, and provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
+        "Do not give a score above 95. Make sure the reward score is an integer. "
+        "Return *ONLY* the score."
+    ),
+)
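+
+# A config for another OpenAI-compatible backend can be assembled the same way.
+# The values below are illustrative placeholders, not shipped defaults:
+#
+# my_prompt_CFG = PromptCFG(
+#     base_url="http://localhost:8000/v1",
+#     model="my-model-name",
+#     critic_system_prompt="Provide a detailed critique to improve the answer.",
+#     refine_system_prompt="Refine the answer based on the critique.",
+#     evaluate_system_prompt="Rate the answer with an integer between -100 and 100. Return only the score.",
+# )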