diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 585f87945..80d77d097 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -171,7 +171,7 @@ class RequestHandler: if self.running_list.ready_for_prefill(): for seq in self.running_list.prefill: seq.mark_running() - self.prefill_batch.init_batch(self.running_list.prefill) + self.prefill_batch.add_seqs(self.running_list.prefill) return self.prefill_batch if not self.running_batch.is_empty: diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 22b5b5a3a..766e54ab1 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -188,24 +188,6 @@ class BatchInfo: if self.fd_inter_tensor is None: self.fd_inter_tensor = FDIntermTensors() - def init_batch(self, seqs: List["Sequence"] = None): - """ - Initializes inference batches by input sentence list. - - Args: - seqs (List["Sequence"]): List of input sequence. - """ - - if seqs is not None: - if not isinstance(seqs, list): - seqs = [seqs] - for seq in seqs: - if seq in self.sequences_set: - logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") - continue - - self.sequences_set.add(seq) - def init_fd_tensors(self): if not self.fd_inter_tensor.is_initialized: self.fd_inter_tensor.initialize( @@ -273,19 +255,19 @@ class BatchInfo: self.sequences_set.discard(seq) return seq - def add_seqs(self, seqs: List["Sequence"]) -> None: + def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None: """ Add new sequence to batch Args: seqs (List["Sequence"]): The list of new sequences. 
""" - - if not isinstance(seqs, list): + # convert single sequence to list + if isinstance(seqs, Sequence): seqs = [seqs] for seq in seqs: - if self.sequences_set and seq in self.sequences_set: + if seq in self.sequences_set: logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") continue self.sequences_set.add(seq) diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py index 16f5bcc7f..e0736518c 100755 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -60,9 +60,8 @@ def check_config_and_inference(): num_heads=2, head_dim=128, ) - batch.init_batch([sequence]) - batch.add_seqs([sequence2, sequence3]) batch.add_seqs([sequence]) + batch.add_seqs([sequence2, sequence3]) assert batch.is_empty == False assert batch.get_batch_size() == 3