@@ -33,7 +33,7 @@ class InferenceEngine:

    Args:
        model (nn.Module): Path or nn.Module of this model.
-        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        tokenizer (Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]): Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy ("Policy"): The policy used to shard the model with shardformer. It will be determined by the model type if not provided.
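
Note: for orientation, a minimal construction sketch using the arguments documented above. The import paths, checkpoint id, and config values are illustrative assumptions, not part of this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.inference.config import InferenceConfig       # assumed module path
from colossalai.inference.core.engine import InferenceEngine  # assumed module path

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inference_config = InferenceConfig(dtype=torch.float16)  # dtype field is read by __init__ in the next hunk
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
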
@@ -42,19 +42,20 @@ class InferenceEngine:

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: Optional["InferenceConfig"] = None,
+        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
-        assert tokenizer, "Please provide a tokenizer, either a defined one or str"
+        assert inference_config, "Please provide inference_config."
+        assert tokenizer, "Please provide a tokenizer, either an instantiated one or a path (str)."
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inference_config = inference_config
        self.model_config = model.config
        self.device = torch.device("cuda")
        self.dtype = inference_config.dtype

        self.generation_config = inference_config.to_generation_config(self.model_config)
        model = model.eval()
        model.to(self.dtype)
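
Note: the added pad-token lines exist because most decoder-only checkpoints ship without a pad token, and batched generation needs one; reusing EOS is the standard workaround. A standalone sketch (the gpt2 checkpoint is illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # gpt2 defines eos_token but no pad_token
tok.pad_token = tok.eos_token                # same trick as __init__ above
batch = tok(["a short prompt", "a noticeably longer prompt"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)              # both rows padded to the same length
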
@@ -80,6 +81,8 @@ class InferenceEngine:

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
+        # DISCUSS: maybe move this into batch info?
+        self.counter = count()

    def _verify_config(self) -> None:
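
Note: self.counter is an itertools.count(), an infinite iterator of increasing integers; here it serves as a cheap source of unique request ids. A minimal illustration:

from itertools import count

counter = count()
first_id = next(counter)   # 0
second_id = next(counter)  # 1
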
@@ -137,7 +140,7 @@
        self,
        prompts: List[str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        generation_config: GenerationConfig = None,
+        generation_config: Optional[GenerationConfig] = None,
    ) -> List[str]:
        """
        Executing the inference step.
@@ -158,6 +161,10 @@
        output_seqs_list = []
        output_tokens_list = []

+        # Intuition: if the user provides a generation config, it replaces the engine's existing one.
+        if generation_config is not None:
+            self.generation_config = generation_config
+
        while self.request_handler.check_unfinished_seqs():
            output_seqs_list += self.step()
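
Note: with generation_config now Optional and replacing the engine's stored config when given, both call styles below work. A hedged sketch; `engine` is the instance constructed earlier, and the sampling values are illustrative.

from transformers import GenerationConfig

outputs = engine.generate(prompts=["Hello, my name is"])  # falls back to the engine's existing config

custom = GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9)
outputs = engine.generate(prompts=["Hello, my name is"], generation_config=custom)
# The custom config replaces self.generation_config, so later calls inherit it.
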
@@ -285,8 +292,8 @@
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]

        self.request_handler.search_tokens(self.generation_config, logits)

        finished_sequences = self.request_handler.update()

        return finished_sequences
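
Note: on padded input the model emits logits for every position with shape (batch, seq_len, vocab_size); only the last position is needed to pick the next token, which is what logits[:, -1, :] keeps. A standalone check with illustrative sizes:

import torch

logits = torch.randn(4, 7, 32000)     # batch=4, seq_len=7, vocab_size=32000
last_token_logits = logits[:, -1, :]  # shape (4, 32000), one row per sequence
print(last_token_logits.shape)        # torch.Size([4, 32000])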