@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -64,7 +64,7 @@ class InferenceEngine:
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
-        model_policy: Policy = None,
+        model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
         self.inference_config = inference_config
         self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@ class InferenceEngine:
         self._verify_args()

-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
         """
         Shard model or/and Load weight

@@ -150,11 +150,17 @@
             )

         if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
+            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+            model_policy = model_policy_map.get(model_policy_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"

         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
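
A minimal usage sketch of what the relaxed annotation enables, assuming ColossalAI's inference module layout; the policy class name, the InferenceConfig fields, and the launch call are illustrative assumptions for review purposes, not part of this patch:

# Sketch only: run under torchrun so the distributed launch below succeeds.
import colossalai
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy  # assumed policy name

colossalai.launch_from_torch(config={})  # signature varies across ColossalAI versions

model_path = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
inference_config = InferenceConfig()  # defaults; fields are version-dependent

# Old behavior: model_policy had to arrive as a Policy instance.
engine = InferenceEngine(
    model, tokenizer, inference_config, model_policy=NoPaddingLlamaModelInferPolicy()
)

# New behavior: a Policy subclass is also accepted; init_model() instantiates
# it through the isinstance()/try-except fallback added in this patch.
engine = InferenceEngine(
    model, tokenizer, inference_config, model_policy=NoPaddingLlamaModelInferPolicy
)

# Omitting model_policy still resolves one from model_policy_map, keyed by
# f"{prefix}_{model_type}", e.g. "nopadding_llama".
engine = InferenceEngine(model, tokenizer, inference_config)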