fix bugs in request_handler

pull/5258/head
yuehuayingxueluo 2024-01-02 13:02:20 +08:00 committed by FrankLeeeee
parent 62fd08ee44
commit 62968588d1
5 changed files with 21 additions and 13 deletions

@@ -226,12 +226,15 @@ class InferenceEngine:
             self.v_cache,
         )
         logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
+        print("finished_sequences: ", finished_sequences)
         # Decode completed sentences.
         for seq in finished_sequences:
+            print("seq.output_token_id: ", seq.output_token_id)
             if seq.prompt:
                 output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
                 output_list.append(seq.prompt + output_str)
@@ -239,4 +242,6 @@ class InferenceEngine:
                 output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
                 output_list.append(output_str)
+        print("len(output_list): ", len(output_list))
         return output_list
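
For context on the decode loop above: a finished sequence that stores its original prompt text (seq.prompt) has only its generated ids decoded, with the prompt string prepended verbatim; otherwise the prompt ids and generated ids are decoded together. A minimal sketch of that logic, assuming a Hugging Face tokenizer; FinishedSeq and decode_finished are illustrative stand-ins, not ColossalAI classes.

from dataclasses import dataclass
from typing import List, Optional

from transformers import AutoTokenizer


@dataclass
class FinishedSeq:
    # Hypothetical stand-in for colossalai.inference.struct.Sequence.
    input_token_id: List[int]
    output_token_id: List[int]
    prompt: Optional[str] = None


def decode_finished(seqs: List[FinishedSeq], tokenizer) -> List[str]:
    output_list = []
    for seq in seqs:
        if seq.prompt:
            # Only the generated tokens are decoded; the stored prompt string is prepended as-is.
            output_str = tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
            output_list.append(seq.prompt + output_str)
        else:
            # Prompt ids and generated ids are decoded together into one string.
            output_str = tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
            output_list.append(output_str)
    return output_list


tokenizer = AutoTokenizer.from_pretrained("gpt2")
seq = FinishedSeq(input_token_id=tokenizer.encode("Hello"), output_token_id=tokenizer.encode(" world"))
print(decode_finished([seq], tokenizer))  # ['Hello world']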

@@ -5,6 +5,7 @@ from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
+from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
@@ -179,10 +180,10 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        # for type in ["top_p", "top_k", "min_p"]:
-        #     config_dict = generation_config.to_dict()
-        #     if type in config_dict:
-        #         logits = logit_processor(type, logits, config_dict[type])
+        for type in ["top_p", "top_k", "min_p"]:
+            config_dict = generation_config.to_dict()
+            if type in config_dict:
+                logits = logit_processor(type, logits, config_dict[type])
         torch.cuda.synchronize()
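
The change above enables the previously commented-out logit-processing loop: for each sampling option present in the generation config (top_p, top_k, min_p), the matching processor is applied to the logits before sampling. The internals of colossalai.inference.logit_processors are not shown in this diff, so the sketch below uses a generic top-k filter and a hand-rolled registry; top_k_logit_filter, _processors, and apply_processors are assumptions, not the library's API.

import torch


def top_k_logit_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest logits in each row and push everything else to -inf,
    # so those tokens receive zero probability after softmax.
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))


# Illustrative registry keyed the same way as the loop above ("top_p", "top_k", "min_p").
_processors = {"top_k": top_k_logit_filter}


def apply_processors(logits: torch.Tensor, config_dict: dict) -> torch.Tensor:
    for name in ["top_p", "top_k", "min_p"]:
        if name in config_dict and name in _processors:
            logits = _processors[name](logits, config_dict[name])
    return logits


logits = torch.randn(2, 32000)               # (batch_size, vocab_size)
logits = apply_processors(logits, {"top_k": 2})
print(torch.isfinite(logits).sum(dim=-1))    # tensor([2, 2]): two candidates survive per row
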
@@ -207,11 +208,12 @@ class RequestHandler:
             self.running_list.prefill.clear()
             self.prefill_batch.clear_batch()
-        for seq in self.running_batch.sequences_set:
-            if seq.check_finish():
-                self.done_list.append(seq)
-                self.running_list.remove(seq)
-                self.running_batch.sequences_set.remove(seq)
-                self.cache_manager.free_block_table(seq.block_table)
-        return self.done_list
+        finish_seqs = self.running_batch.fliter_batch()
+        for seq in finish_seqs:
+            self.running_list.remove(seq)
+            self.cache_manager.free_block_table(seq.block_table)
+        self.done_list.extend(finish_seqs)
+        return finish_seqs
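
The rewritten update() lets the batch report its own finished requests (fliter_batch pops them out of the running batch) and then, per finished sequence, drops it from the running list and frees its KV-cache blocks, instead of removing items from running_batch.sequences_set while iterating over it as the old loop did. It also returns only the sequences that finished in this step rather than the cumulative done_list. A simplified sketch of that flow, assuming toy stand-ins; ToySeq, ToyBatch, and the free_block_table callback do not mirror ColossalAI's real classes.

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySeq:
    # Toy request: the ids generated so far plus the KV-cache blocks it owns.
    seq_id: int
    block_table: List[int]
    output_token_id: List[int] = field(default_factory=list)
    eos_token_id: int = 2
    max_output_len: int = 16

    def check_finish(self) -> bool:
        if not self.output_token_id:
            return False
        return (self.output_token_id[-1] == self.eos_token_id
                or len(self.output_token_id) >= self.max_output_len)


class ToyBatch:
    def __init__(self, seqs: List[ToySeq]):
        self.sequences = list(seqs)

    def fliter_batch(self) -> List[ToySeq]:
        # Same idea as the batch-side fliter_batch() used above: pop finished
        # sequences out of the running batch and hand them back to the caller.
        finished = [s for s in self.sequences if s.check_finish()]
        self.sequences = [s for s in self.sequences if not s.check_finish()]
        return finished


def update(running_batch: ToyBatch, done_list: List[ToySeq], free_block_table) -> List[ToySeq]:
    finish_seqs = running_batch.fliter_batch()
    for seq in finish_seqs:
        free_block_table(seq.block_table)    # release KV-cache blocks exactly once
    done_list.extend(finish_seqs)
    return finish_seqs                       # only what finished in this step


done: List[ToySeq] = []
batch = ToyBatch([ToySeq(0, [0, 1], output_token_id=[5, 2]), ToySeq(1, [2, 3], output_token_id=[7])])
finished = update(batch, done, free_block_table=lambda blocks: print("freed blocks", blocks))
print([s.seq_id for s in finished], len(batch.sequences))  # [0] 1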

@@ -106,7 +106,7 @@ class Sequence:
             return True
         if self.output_token_id:
-            if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
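
The old termination check treated any last token whose id was greater than or equal to eos_token_id as an end-of-sequence and only stopped on an exact length match; the fix compares against the EOS id for equality and caps the length with >= (output_len presumably tracks the number of generated tokens, approximated below with len()). A tiny comparison of the two predicates as plain functions, just to show the behavioural difference.

def old_finished(output_token_id, eos_token_id, max_output_len):
    return (output_token_id[-1] >= eos_token_id
            or len(output_token_id) == max_output_len)


def new_finished(output_token_id, eos_token_id, max_output_len):
    return (output_token_id[-1] == eos_token_id
            or len(output_token_id) >= max_output_len)


# An ordinary token whose id merely exceeds eos_token_id must not stop generation.
print(old_finished([42], eos_token_id=2, max_output_len=8))             # True  (spurious stop)
print(new_finished([42], eos_token_id=2, max_output_len=8))             # False

# If the output ever steps past the cap, '==' never fires again while '>=' still stops.
print(old_finished([5, 7, 9], eos_token_id=32000, max_output_len=2))    # False (missed)
print(new_finished([5, 7, 9], eos_token_id=32000, max_output_len=2))    # True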

@@ -28,6 +28,7 @@ def check_inference_engine():
     generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
     outputs = inference_engine.generate(generation_config)
     print("len(outputs): ", len(outputs))
     print("outputs: ", outputs)
     # Engine still gets some bug