|
|
|
@ -188,24 +188,6 @@ class BatchInfo:
|
|
|
|
|
if self.fd_inter_tensor is None: |
|
|
|
|
self.fd_inter_tensor = FDIntermTensors() |
|
|
|
|
|
|
|
|
|
def init_batch(self, seqs: List["Sequence"] = None): |
|
|
|
|
""" |
|
|
|
|
Initializes inference batches by input sentence list. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
seqs (List["Sequence"]): List of input sequence. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if seqs is not None: |
|
|
|
|
if not isinstance(seqs, list): |
|
|
|
|
seqs = [seqs] |
|
|
|
|
for seq in seqs: |
|
|
|
|
if seq in self.sequences_set: |
|
|
|
|
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
self.sequences_set.add(seq) |
|
|
|
|
|
|
|
|
|
def init_fd_tensors(self): |
|
|
|
|
if not self.fd_inter_tensor.is_initialized: |
|
|
|
|
self.fd_inter_tensor.initialize( |
|
|
|
@ -273,19 +255,19 @@ class BatchInfo:
|
|
|
|
|
self.sequences_set.discard(seq) |
|
|
|
|
return seq |
|
|
|
|
|
|
|
|
|
def add_seqs(self, seqs: List["Sequence"]) -> None: |
|
|
|
|
def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None: |
|
|
|
|
""" |
|
|
|
|
Add new sequence to batch |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
seqs (List["Sequence"]): The list of new sequences. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if not isinstance(seqs, list): |
|
|
|
|
# covnert single sequence to list |
|
|
|
|
if isinstance(seqs, Sequence): |
|
|
|
|
seqs = [seqs] |
|
|
|
|
|
|
|
|
|
for seq in seqs: |
|
|
|
|
if self.sequences_set and seq in self.sequences_set: |
|
|
|
|
if seq in self.sequences_set: |
|
|
|
|
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") |
|
|
|
|
continue |
|
|
|
|
self.sequences_set.add(seq) |
|
|
|
|