diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index 100cc5ece..ef904b864 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -27,11 +27,11 @@
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [O1 Journey](#o1-journey)
+ - [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
-- [The Plan](#the-plan)
- - [Real-time progress](#real-time-progress)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview)
- [Authors](#authors)
@@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.
-### Inference Quantization and Serving - After Training
+## Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
@@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
+## O1 Journey
+### Inference with Self-refined MCTS
+We provide the implementation of MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
+To run inference with MCTS, simply use the following script.
+```python
+from coati.reasoner.guided_search.mcts import MCTS
+from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
+
+problem = "How Many R in 'Strawberry'"
+
+search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
+answer = search_tree.simulate()
+print(answer)
+```
+
## Coati7B examples
### Generation
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/llm.py b/applications/ColossalChat/coati/reasoner/guided_search/llm.py
new file mode 100644
index 000000000..5025a98ea
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/llm.py
@@ -0,0 +1,26 @@
+import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+
+API_KEY = "Dummy API Key"
+
+
+def get_client(base_url: str | None = None) -> openai.Client:
+ return openai.Client(api_key=API_KEY, base_url=base_url)
+
+
+def chat_completion(
+ messages: list[ChatCompletionMessageParam],
+ model: str,
+ base_url: str | None = None,
+ temperature: float = 0.8,
+ **kwargs,
+) -> ChatCompletion:
+ client = get_client(base_url)
+ response = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=temperature,
+ **kwargs,
+ )
+ return response
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
new file mode 100644
index 000000000..693e2b750
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -0,0 +1,250 @@
+"""
+Implementation of MCTS + Self-refine algorithm.
+
+Reference:
+1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
+Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
+2. https://github.com/BrendanGraham14/mcts-llm/
+3. https://github.com/trotsky1997/MathBlackBox/
+4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
+"""
+
+from __future__ import annotations
+
+import math
+from collections import deque
+
+import numpy as np
+import tqdm
+from coati.reasoner.guided_search.llm import chat_completion
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+from pydantic import BaseModel
+
+
+class MCTSNode(BaseModel):
+ """
+ Node for MCTS.
+ """
+
+ answer: str
+ parent: MCTSNode = None
+ children: list[MCTSNode] = []
+ num_visits: int = 0
+ Q: int = 0
+ rewards: list[int] = []
+
+ def expand_node(self, node) -> None:
+ self.children.append(node)
+
+ def add_reward(self, reward: int) -> None:
+ self.rewards.append(reward)
+ self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
+
+
+class MCTS(BaseModel):
+ """
+ Simulation of MCTS process.
+ """
+
+ problem: str
+ max_simulations: int
+ cfg: PromptCFG
+ C: float = 1.4
+ max_children: int = 2
+ epsilon: float = 1e-5
+ root: MCTSNode = None
+
+ def initialization(self):
+ """
+ Root Initiation.
+ """
+ # Dummy answer as root.
+ base_answer = self.sample_base_answer()
+ self.root = MCTSNode(answer=base_answer)
+ self.self_evaluate(self.root)
+
+ def is_fully_expanded(self, node: MCTSNode):
+ return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
+
+ def select_node(self) -> MCTSNode:
+ """
+ Select next node to explore.
+ """
+ candidates: list[MCTSNode] = []
+ to_explore = deque([self.root])
+
+ while to_explore:
+ current_node = to_explore.popleft()
+ if not self.is_fully_expanded(current_node):
+ candidates.append(current_node)
+ to_explore.extend(current_node.children)
+
+ if not candidates:
+ return self.root
+
+ return max(candidates, key=self.compute_uct)
+
+ def self_evaluate(self, node: MCTSNode):
+ """
+ Sample reward of the answer.
+ """
+ reward = self.sample_reward(node)
+ node.add_reward(reward)
+
+ def back_propagation(self, node: MCTSNode):
+ """
+ Back propagate the value of the refined answer.
+ """
+ parent = node.parent
+ while parent:
+ best_child_Q = max(child.Q for child in parent.children)
+ parent.Q = (parent.Q + best_child_Q) / 2
+ parent.num_visits += 1
+ parent = parent.parent
+
+ def compute_uct(self, node: MCTSNode):
+ """
+ Compute UCT.
+ """
+ if node.parent is None:
+ return -100
+ return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
+
+ def simulate(self):
+ self.initialization()
+ for _ in tqdm.tqdm(range(self.max_simulations)):
+ node = self.select_node()
+ child = self.self_refine(node)
+ node.expand_node(child)
+ self.self_evaluate(child)
+ self.back_propagation(child)
+
+ return self.get_best_answer()
+
+ def get_best_answer(self):
+ to_visit = deque([self.root])
+ best_node = self.root
+
+ while to_visit:
+ current_node = to_visit.popleft()
+ if current_node.Q > best_node.Q:
+ best_node = current_node
+ to_visit.extend(current_node.children)
+
+ return best_node.answer
+
+ def self_refine(self, node: MCTSNode):
+ """
+ Refine node.
+ """
+ critique_response = chat_completion(
+ messages=[
+ {
+ "role": "system",
+ "content": self.cfg.critic_system_prompt,
+ },
+ {
+ "role": "user",
+ "content": "\n\n".join(
+ [
+ f"\n{self.problem}\n",
+ f"\n{node.answer}\n",
+ ]
+ ),
+ },
+ ],
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ critique = critique_response.choices[0].message.content
+ assert critique is not None
+ refined_answer_response = chat_completion(
+ messages=[
+ {
+ "role": "system",
+ "content": self.cfg.refine_system_prompt,
+ },
+ {
+ "role": "user",
+ "content": "\n\n".join(
+ [
+ f"\n{self.problem}\n",
+ f"\n{node.answer}\n",
+ f"\n{critique}\n",
+ ]
+ ),
+ },
+ ],
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ refined_answer = refined_answer_response.choices[0].message.content
+ assert refined_answer is not None
+
+ return MCTSNode(answer=refined_answer, parent=node)
+
+ def sample_base_answer(self):
+ response = chat_completion(
+ messages=[
+ {
+ "role": "system",
+ "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
+ },
+ {
+ "role": "user",
+ "content": f"\n {self.problem} \n \nLet's think step by step",
+ },
+ ],
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ assert response.choices[0].message.content is not None
+ return response.choices[0].message.content
+
+ def sample_reward(self, node: MCTSNode):
+ """
+ Calculate reward.
+ """
+ messages = [
+ {
+ "role": "system",
+ "content": self.cfg.evaluate_system_prompt,
+ },
+ {
+ "role": "user",
+ "content": "\n\n".join(
+ [
+ f"\n{self.problem}\n",
+ f"\n{node.answer}\n",
+ ]
+ ),
+ },
+ ]
+ for attempt in range(3):
+ try:
+ response = chat_completion(
+ messages=messages,
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ assert response.choices[0].message.content is not None
+ return int(response.choices[0].message.content)
+ except ValueError:
+ messages.extend(
+ [
+ {
+ "role": "assistant",
+ "content": response.choices[0].message.content,
+ },
+ {
+ "role": "user",
+ "content": "Failed to parse reward as an integer.",
+ },
+ ]
+ )
+ if attempt == 2:
+ raise
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
new file mode 100644
index 000000000..b325b8fa2
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class PromptCFG(BaseModel):
+ model: str
+ base_url: str
+ max_tokens: int = 4096
+ critic_system_prompt: str
+ refine_system_prompt: str
+ evaluate_system_prompt: str
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
new file mode 100644
index 000000000..8bf0fa959
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
@@ -0,0 +1,20 @@
+"""
+Prompts for Qwen Series.
+"""
+
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+
+Qwen32B_prompt_CFG = PromptCFG(
+ base_url="http://0.0.0.0:8008/v1",
+ model="Qwen2.5-32B-Instruct",
+ critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
+ "Highlight specific areas that need refinement or correction.",
+ refine_system_prompt="""# Instruction
+ Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
+ """,
+ evaluate_system_prompt=(
+ "Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
+ "Do not give a full score above 95. Make sure the reward score is an integer. "
+ "Return *ONLY* the score."
+ ),
+)