diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index e88962f85..1addea1d4 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -134,12 +134,16 @@ class InferenceEngine:
 
     def generate(
         self,
+        prompts: List[str] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         generation_config: GenerationConfig = None,
     ) -> List[str]:
         """
         Executing the inference step.
 
         Args:
+            prompts (Union[List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
             generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
 
         Returns:
@@ -147,13 +151,23 @@ class InferenceEngine:
         """
 
         self.generation_config = generation_config
+        if prompts is not None or prompts_token_ids is not None:
+            self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
 
-        output_list = []
+        output_seqs_list = []
+        output_tokens_list = []
 
         while self.request_handler.check_unfinished_seqs():
-            output_list += self.step()
+            output_seqs_list += self.step()
 
-        return output_list
+        output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+
+        for seq in output_seqs_list:
+            output_tokens_list.append(seq.input_token_id + seq.output_token_id)
+
+        output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
+
+        return output_str
 
     def add_request(
         self,
@@ -235,7 +249,6 @@ class InferenceEngine:
             List[str]: Decoded finished sequences generated by one step.
         """
 
-        output_list = []
         batch = self.request_handler.schedule()
 
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
@@ -251,10 +264,4 @@ class InferenceEngine:
         self.request_handler.search_tokens(self.generation_config, logits)
 
         finished_sequences = self.request_handler.update()
-        # Decode completed sentences.
-        # TODO : update decoding step
-        for seq in finished_sequences:
-            output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
-            output_list.append(output_str)
-
-        return output_list
+        return finished_sequences
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index 267e56231..780c08891 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -141,8 +141,7 @@ def benchmark_inference(args):
     with ctx:
         for _ in range(N_WARMUP_STEPS):
             if args.mode == "caiinference":
-                engine.add_request(prompts_token_ids=data)
-                engine.generate(generation_config)
+                engine.generate(prompts_token_ids=data, generation_config=generation_config)
             else:
                 engine.generate(data, generation_config=generation_config)
         if args.profile:
@@ -156,8 +155,7 @@ def benchmark_inference(args):
     whole_end2end = time.perf_counter()
     if args.mode == "caiinference":
         for _ in range(args.batch_size // mbsz):
-            engine.add_request(prompts_token_ids=data)
-            engine.generate(generation_config)
+            engine.generate(prompts_token_ids=data, generation_config=generation_config)
     else:
         for _ in range(args.batch_size // mbsz):
            engine.generate(data, generation_config=generation_config)
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 49bbe6df3..8c8e864b0 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -49,7 +49,7 @@ def check_inference_engine(test_cai=False):
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
         generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
-        outputs = inference_engine.generate(generation_config)
+        outputs = inference_engine.generate(generation_config=generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
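
For reference, this change folds request submission into InferenceEngine.generate: callers may now pass prompts or prompts_token_ids directly, and generate queues them via add_request, steps until every sequence finishes, sorts the finished sequences by request_id, and batch-decodes them into strings. Below is a minimal usage sketch of the new call pattern; engine construction is elided, and the engine object, prompt text, and GenerationConfig values are illustrative assumptions rather than part of this diff.

    # Sketch of the updated generate() API; assumes an InferenceEngine instance
    # named `engine` has already been built (see tests/test_infer/test_inference_engine.py
    # for a full setup).
    from transformers import GenerationConfig

    generation_config = GenerationConfig(do_sample=False, max_new_tokens=32)

    # Old flow (before this change): queue requests first, then drain them.
    #   engine.add_request(prompts=["Introduce some landmarks in Beijing"])
    #   outputs = engine.generate(generation_config)

    # New flow: prompts go straight into generate(), which calls add_request()
    # internally and returns the decoded strings, ordered by request_id.
    outputs = engine.generate(
        prompts=["Introduce some landmarks in Beijing"],
        generation_config=generation_config,
    )
    print(outputs)  # list of decoded strings, one per prompt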