[inference] removed redundancy init_batch (#5353)

Branch: pull/5354/head
Author: Frank Lee, 10 months ago (committed by GitHub)
Parent: 249644c23b
Commit: db1a763307

@@ -171,7 +171,7 @@ class RequestHandler:
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
-            self.prefill_batch.init_batch(self.running_list.prefill)
+            self.prefill_batch.add_seqs(self.running_list.prefill)
             return self.prefill_batch

         if not self.running_batch.is_empty:

@@ -188,24 +188,6 @@ class BatchInfo:
         if self.fd_inter_tensor is None:
             self.fd_inter_tensor = FDIntermTensors()

-    def init_batch(self, seqs: List["Sequence"] = None):
-        """
-        Initializes inference batches by input sentence list.
-
-        Args:
-            seqs (List["Sequence"]): List of input sequence.
-        """
-        if seqs is not None:
-            if not isinstance(seqs, list):
-                seqs = [seqs]
-
-            for seq in seqs:
-                if seq in self.sequences_set:
-                    logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                    continue
-
-                self.sequences_set.add(seq)
-
     def init_fd_tensors(self):
         if not self.fd_inter_tensor.is_initialized:
             self.fd_inter_tensor.initialize(
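
The deleted init_batch duplicated the body of add_seqs almost line for line: wrap a single argument into a list, then insert each sequence into sequences_set, warning on duplicates. A self-contained sketch of that shared pattern in plain Python; ToyBatch and the string request ids are illustrative stand-ins, not ColossalAI classes:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class ToyBatch:
    """Illustrative stand-in for BatchInfo: a set of requests with a duplicate guard."""

    def __init__(self):
        self.sequences_set = set()

    def add_seqs(self, seqs):
        # Wrap a single item into a list, as both the removed init_batch
        # and the surviving add_seqs do.
        if not isinstance(seqs, list):
            seqs = [seqs]
        for seq in seqs:
            if seq in self.sequences_set:
                logger.warning(f"The sequence(request_id {seq}) is already in sequences_set.")
                continue
            self.sequences_set.add(seq)


batch = ToyBatch()
batch.add_seqs("req-0")             # a single item is accepted directly
batch.add_seqs(["req-1", "req-0"])  # "req-0" is skipped with a warning
assert len(batch.sequences_set) == 2
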
@@ -273,19 +255,19 @@ class BatchInfo:
         self.sequences_set.discard(seq)
         return seq

-    def add_seqs(self, seqs: List["Sequence"]) -> None:
+    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
         """
         Add new sequence to batch

         Args:
             seqs (List["Sequence"]): The list of new sequences.
         """
         # covnert single sequence to list
-        if not isinstance(seqs, list):
+        if isinstance(seqs, Sequence):
             seqs = [seqs]

         for seq in seqs:
-            if self.sequences_set and seq in self.sequences_set:
+            if seq in self.sequences_set:
                 logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
             self.sequences_set.add(seq)
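
The updated signature accepts either a single Sequence or a list of them, with the isinstance check normalizing the single-item case; switching the check from not isinstance(seqs, list) to isinstance(seqs, Sequence) ties that branch to the concrete request type. A minimal sketch of the same normalization as a standalone helper; the as_list name and the generic typing are assumptions for illustration, not ColossalAI APIs:

from typing import List, TypeVar, Union

T = TypeVar("T")


def as_list(item_or_items: Union[T, List[T]]) -> List[T]:
    """Wrap a single item into a list, mirroring the isinstance check in add_seqs."""
    return item_or_items if isinstance(item_or_items, list) else [item_or_items]


assert as_list("seq-0") == ["seq-0"]
assert as_list(["seq-0", "seq-1"]) == ["seq-0", "seq-1"]
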

@@ -60,9 +60,8 @@ def check_config_and_inference():
         num_heads=2,
         head_dim=128,
     )
-    batch.init_batch([sequence])
-    batch.add_seqs([sequence2, sequence3])
     batch.add_seqs([sequence])
+    batch.add_seqs([sequence2, sequence3])

     assert batch.is_empty == False
     assert batch.get_batch_size() == 3
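
The updated test drives the batch entirely through add_seqs and asserts on the resulting size. The same shape can be exercised against a plain stub without the ColossalAI runtime; a pytest-style sketch in which StubBatch mimics only the BatchInfo surface visible in the hunk above (is_empty, get_batch_size, add_seqs) and is not the real class:

class StubBatch:
    """Stand-in exposing only the members exercised by the test above."""

    def __init__(self):
        self.sequences_set = set()

    @property
    def is_empty(self):
        return not self.sequences_set

    def get_batch_size(self):
        return len(self.sequences_set)

    def add_seqs(self, seqs):
        if not isinstance(seqs, list):
            seqs = [seqs]
        for seq in seqs:
            # set.add() silently skips duplicates, matching the guarded insert above
            self.sequences_set.add(seq)


def test_add_seqs_populates_batch():
    batch = StubBatch()
    batch.add_seqs(["sequence"])
    batch.add_seqs(["sequence2", "sequence3"])
    assert batch.is_empty == False
    assert batch.get_batch_size() == 3
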
