From 393c8f5b7fc3051ac01efd8e23c8f6750d67f4c9 Mon Sep 17 00:00:00 2001 From: hugo-syn <61210734+hugo-syn@users.noreply.github.com> Date: Mon, 13 May 2024 15:06:44 +0200 Subject: [PATCH] [hotfix] fix inference typo (#5438) --- colossalai/legacy/inference/async_manager.py | 6 +++--- colossalai/legacy/inference/manager.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/legacy/inference/async_manager.py b/colossalai/legacy/inference/async_manager.py index 60440a792..526e0f632 100644 --- a/colossalai/legacy/inference/async_manager.py +++ b/colossalai/legacy/inference/async_manager.py @@ -55,14 +55,14 @@ class Async_DynamicBatchManager(DynamicBatchManager): self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch has_new_finished, outputs = self._prefill_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens = 0 else: if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 else: @@ -78,7 +78,7 @@ class Async_DynamicBatchManager(DynamicBatchManager): else: self.stats_tool.count_output_tokens(self.running_batch) has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 if has_new_finished: diff --git a/colossalai/legacy/inference/manager.py b/colossalai/legacy/inference/manager.py index 9672a5014..050dc22b5 100644 --- a/colossalai/legacy/inference/manager.py +++ b/colossalai/legacy/inference/manager.py @@ -131,14 +131,14 @@ class DynamicBatchManager: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch yield from self._prefill_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) yield from self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 return else: @@ -154,7 +154,7 @@ class DynamicBatchManager: else: self.stats_tool.count_output_tokens(self.running_batch) yield from self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 return @@ -243,7 +243,7 @@ class DynamicBatchManager: self._filter_batch(batch) yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): + def _filter_running_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None