mirror of https://github.com/hpcaitech/ColossalAI
[inference] removed redundancy init_batch (#5353)
parent
249644c23b
commit
db1a763307
|
@ -171,7 +171,7 @@ class RequestHandler:
|
|||
if self.running_list.ready_for_prefill():
|
||||
for seq in self.running_list.prefill:
|
||||
seq.mark_running()
|
||||
self.prefill_batch.init_batch(self.running_list.prefill)
|
||||
self.prefill_batch.add_seqs(self.running_list.prefill)
|
||||
return self.prefill_batch
|
||||
|
||||
if not self.running_batch.is_empty:
|
||||
|
|
|
@ -188,24 +188,6 @@ class BatchInfo:
|
|||
if self.fd_inter_tensor is None:
|
||||
self.fd_inter_tensor = FDIntermTensors()
|
||||
|
||||
def init_batch(self, seqs: List["Sequence"] = None):
|
||||
"""
|
||||
Initializes inference batches by input sentence list.
|
||||
|
||||
Args:
|
||||
seqs (List["Sequence"]): List of input sequence.
|
||||
"""
|
||||
|
||||
if seqs is not None:
|
||||
if not isinstance(seqs, list):
|
||||
seqs = [seqs]
|
||||
for seq in seqs:
|
||||
if seq in self.sequences_set:
|
||||
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
|
||||
continue
|
||||
|
||||
self.sequences_set.add(seq)
|
||||
|
||||
def init_fd_tensors(self):
|
||||
if not self.fd_inter_tensor.is_initialized:
|
||||
self.fd_inter_tensor.initialize(
|
||||
|
@ -273,19 +255,19 @@ class BatchInfo:
|
|||
self.sequences_set.discard(seq)
|
||||
return seq
|
||||
|
||||
def add_seqs(self, seqs: List["Sequence"]) -> None:
|
||||
def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
|
||||
"""
|
||||
Add new sequence to batch
|
||||
|
||||
Args:
|
||||
seqs (List["Sequence"]): The list of new sequences.
|
||||
"""
|
||||
|
||||
if not isinstance(seqs, list):
|
||||
# covnert single sequence to list
|
||||
if isinstance(seqs, Sequence):
|
||||
seqs = [seqs]
|
||||
|
||||
for seq in seqs:
|
||||
if self.sequences_set and seq in self.sequences_set:
|
||||
if seq in self.sequences_set:
|
||||
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
|
||||
continue
|
||||
self.sequences_set.add(seq)
|
||||
|
|
|
@ -60,9 +60,8 @@ def check_config_and_inference():
|
|||
num_heads=2,
|
||||
head_dim=128,
|
||||
)
|
||||
batch.init_batch([sequence])
|
||||
batch.add_seqs([sequence2, sequence3])
|
||||
batch.add_seqs([sequence])
|
||||
batch.add_seqs([sequence2, sequence3])
|
||||
|
||||
assert batch.is_empty == False
|
||||
assert batch.get_batch_size() == 3
|
||||
|
|
Loading…
Reference in New Issue