update best answer function

pull/6119/head
Tong Li 2024-11-08 03:30:21 +00:00
parent 30a9443132
commit 6c619c9992
1 changed files with 6 additions and 7 deletions

View File

@ -120,17 +120,16 @@ class MCTS(BaseModel):
self.back_propagation(child)
return self.get_best_answer()
def get_best_answer(self):
def _iter_nodes(self):
to_visit = deque([self.root])
best_node = self.root
while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
yield current_node
to_visit.extend(current_node.children)
def get_best_answer(self):
best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
return best_node.answer
def self_refine(self, node: MCTSNode):