mirror of https://github.com/hpcaitech/ColossalAI
update best answer function
parent
30a9443132
commit
6c619c9992
|
@ -120,17 +120,16 @@ class MCTS(BaseModel):
|
|||
self.back_propagation(child)
|
||||
|
||||
return self.get_best_answer()
|
||||
|
||||
def get_best_answer(self):
|
||||
|
||||
def _iter_nodes(self):
|
||||
to_visit = deque([self.root])
|
||||
best_node = self.root
|
||||
|
||||
while to_visit:
|
||||
current_node = to_visit.popleft()
|
||||
if current_node.Q > best_node.Q:
|
||||
best_node = current_node
|
||||
yield current_node
|
||||
to_visit.extend(current_node.children)
|
||||
|
||||
|
||||
def get_best_answer(self):
|
||||
best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
|
||||
return best_node.answer
|
||||
|
||||
def self_refine(self, node: MCTSNode):
|
||||
|
|
Loading…
Reference in New Issue