mirror of https://github.com/hpcaitech/ColossalAI
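Update best-answer function: extract the breadth-first node traversal into an `_iter_nodes` generator and have `get_best_answer` select the highest-Q node with `max()`, falling back to the root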
parent
30a9443132
commit
6c619c9992
|
@ -120,17 +120,16 @@ class MCTS(BaseModel):
|
||||||
self.back_propagation(child)
|
self.back_propagation(child)
|
||||||
|
|
||||||
return self.get_best_answer()
|
return self.get_best_answer()
|
||||||
|
|
||||||
def get_best_answer(self):
|
def _iter_nodes(self):
|
||||||
to_visit = deque([self.root])
|
to_visit = deque([self.root])
|
||||||
best_node = self.root
|
|
||||||
|
|
||||||
while to_visit:
|
while to_visit:
|
||||||
current_node = to_visit.popleft()
|
current_node = to_visit.popleft()
|
||||||
if current_node.Q > best_node.Q:
|
yield current_node
|
||||||
best_node = current_node
|
|
||||||
to_visit.extend(current_node.children)
|
to_visit.extend(current_node.children)
|
||||||
|
|
||||||
|
def get_best_answer(self):
|
||||||
|
best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
|
||||||
return best_node.answer
|
return best_node.answer
|
||||||
|
|
||||||
def self_refine(self, node: MCTSNode):
|
def self_refine(self, node: MCTSNode):
|
||||||
|
|
Loading…
Reference in New Issue