diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 0f6705157..0dc03d4ae 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -226,12 +226,15 @@ class InferenceEngine:
             self.v_cache,
         )
 
+        logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        finished_sequences = self.request_handler.update()
+        print("finished_sequences: ", finished_sequences)
+
         # Decode completed sentences.
         for seq in finished_sequences:
+            print("seq.output_token_id: ", seq.output_token_id)
             if seq.prompt:
                 output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
                 output_list.append(seq.prompt + output_str)
@@ -239,4 +242,6 @@
                 output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
                 output_list.append(output_str)
 
+        print("len(output_list): ", len(output_list))
+
         return output_list
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 3cc203470..e383640f7 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -5,6 +5,7 @@
 from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
+from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
 
@@ -179,10 +180,10 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        # for type in ["top_p", "top_k", "min_p"]:
-        #     config_dict = generation_config.to_dict()
-        #     if type in config_dict:
-        #         logits = logit_processor(type, logits, config_dict[type])
+        for type in ["top_p", "top_k", "min_p"]:
+            config_dict = generation_config.to_dict()
+            if type in config_dict:
+                logits = logit_processor(type, logits, config_dict[type])
 
         torch.cuda.synchronize()
 
@@ -207,11 +208,12 @@ class RequestHandler:
         self.running_list.prefill.clear()
         self.prefill_batch.clear_batch()
 
-        for seq in self.running_batch.sequences_set:
-            if seq.check_finish():
-                self.done_list.append(seq)
-                self.running_list.remove(seq)
-                self.running_batch.sequences_set.remove(seq)
-                self.cache_manager.free_block_table(seq.block_table)
+        finish_seqs = self.running_batch.fliter_batch()
 
-        return self.done_list
+        for seq in finish_seqs:
+            self.running_list.remove(seq)
+            self.cache_manager.free_block_table(seq.block_table)
+
+        self.done_list.extend(finish_seqs)
+
+        return finish_seqs
diff --git a/colossalai/inference/modeling/models/__init__.py b/colossalai/inference/modeling/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index 6133008fe..6ea5d288c 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -106,7 +106,7 @@ class Sequence:
             return True
 
         if self.output_token_id:
-            if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
 
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 26c9d5f96..d9b6b4089 100755
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -28,6 +28,7 @@ def check_inference_engine():
     generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
     outputs = inference_engine.generate(generation_config)
 
+    print("len(outputs): ", len(outputs))
     print("outputs: ", outputs)
 
     # Engine still gets some bug