@@ -57,6 +57,7 @@ class InferenceEngine:
         self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.generation_config = inference_config.to_generation_config(self.model_config)
 
         model = model.eval()
         model = model.cuda()
         model.to(self.dtype)
         if model_policy is None:
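The added line caches a Hugging Face `GenerationConfig` derived from the inference config once at construction time instead of rebuilding it per call. For readers unfamiliar with the pattern, a minimal sketch of what a `to_generation_config`-style mapping can look like; the dataclass and its field names below are illustrative assumptions, not ColossalAI's actual `InferenceConfig`:

```python
# Illustrative sketch only: ToyInferenceConfig and its fields are assumptions,
# not the real InferenceConfig implementation.
from dataclasses import dataclass
from transformers import GenerationConfig

@dataclass
class ToyInferenceConfig:
    max_new_tokens: int = 128
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1.0

    def to_generation_config(self, model_config: dict) -> GenerationConfig:
        # Copy sampling fields and let the model config supply token ids.
        return GenerationConfig(
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            eos_token_id=model_config.get("eos_token_id"),
        )

cfg = ToyInferenceConfig(temperature=0.7).to_generation_config({"eos_token_id": 2})
print(cfg.max_new_tokens, cfg.eos_token_id)  # 128 2
```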
@@ -133,12 +134,13 @@ class InferenceEngine:
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
-        return shard_model.cuda()
+        return shard_model
 
     def generate(
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
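Dropping the trailing `.cuda()` looks safe because `__init__` already moves the model to the GPU before sharding (first hunk), and calling `.cuda()` on parameters that are already resident returns the same tensors. A quick check of that assumption (requires a CUDA device):

```python
import torch

m = torch.nn.Linear(4, 4).cuda()
ptr = next(m.parameters()).data_ptr()
m = m.cuda()  # second .cuda() is a no-op: tensors already live on the GPU
assert next(m.parameters()).data_ptr() == ptr
print(next(m.parameters()).device)  # cuda:0
```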
@@ -148,6 +150,7 @@ class InferenceEngine:
 
         Args:
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            request_ids (List[int], optional): The request ID. Defaults to None.
             return_token_ids (bool): Whether to return output token ids. Defaults to False.
             generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
@@ -157,7 +160,7 @@ class InferenceEngine:
         with torch.inference_mode():
             self.generation_config = generation_config
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
+                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
 
             output_seqs_list = []
             total_tokens_list = []
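With `request_ids` now threaded from `generate` into `add_request`, callers can pin their own IDs to prompts. A hedged usage sketch; `engine` stands in for an already constructed `InferenceEngine` and its setup is not shown:

```python
# Hypothetical call site; keyword names follow the new signature above.
prompts = ["Hello, world!", "Summarize KV caching in one sentence."]
request_ids = [101, 102]  # one int per prompt; validated in add_request

outputs = engine.generate(
    prompts=prompts,
    request_ids=request_ids,
    return_token_ids=False,  # return decoded text only
)
for rid, text in zip(request_ids, outputs):
    print(rid, text)
```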
@@ -204,7 +207,7 @@ class InferenceEngine:
 
     def add_request(
         self,
-        requests_id: List[int] = None,
+        request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
@@ -212,7 +215,7 @@ class InferenceEngine:
         Add requests.
 
         Args:
-            requests_id (List[int], optional): The request ID. Defaults to None.
+            request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
         """
@@ -223,6 +226,9 @@ class InferenceEngine:
 
         block_size = self.inference_config.block_size
 
+        if prompts is not None and not isinstance(prompts, list):
+            prompts = [prompts]
+
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
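The two added lines let callers pass a bare string instead of a list before the batch tokenization below. A standalone sketch of that normalization plus the `batch_encode_plus` call; `"gpt2"` is just an example checkpoint, not something the engine prescribes:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # mirrors the engine's __init__

prompts = "a single prompt"
if prompts is not None and not isinstance(prompts, list):
    prompts = [prompts]  # wrap a bare string so batch encoding sees a list

# batch_encode_plus returns a dict; "input_ids" is one id list per prompt
prompts_token_ids = tokenizer.batch_encode_plus(prompts, padding=False)["input_ids"]
print(prompts_token_ids)  # e.g. [[64, 2060, 6152]]
```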
@@ -245,8 +251,14 @@ class InferenceEngine:
         prompts_num = len(prompts_token_ids)
 
         for i in range(prompts_num):
-            if requests_id:
-                request_id = requests_id[i]
+            if request_ids:
+                if not isinstance(request_ids, list):
+                    request_ids = [request_ids]
+                assert isinstance(
+                    request_ids[0], int
+                ), f"The request_id type must be int, but got {type(request_ids[0])}"
+                assert len(request_ids) == prompts_num
+                request_id = request_ids[i]
             else:
                 request_id = next(self.counter)
             if prompts == None:
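The replaced branch now accepts a bare int, normalizes it to a list, and asserts element type and length before indexing. The same logic as a standalone helper, purely for readability; the engine inlines these checks rather than defining a function like this:

```python
from typing import List, Union

def normalize_request_ids(request_ids: Union[int, List[int]], prompts_num: int) -> List[int]:
    # Hypothetical helper mirroring the checks inlined in add_request.
    if not isinstance(request_ids, list):
        request_ids = [request_ids]  # a bare int means a single request
    assert isinstance(
        request_ids[0], int
    ), f"The request_id type must be int, but got {type(request_ids[0])}"
    assert len(request_ids) == prompts_num, "exactly one ID per prompt"
    return request_ids

print(normalize_request_ids(7, 1))       # [7]
print(normalize_request_ids([1, 2], 2))  # [1, 2]
```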