mirror of https://github.com/hpcaitech/ColossalAI
[Fix/Inference] Add unsupported auto-policy error message (#5730)
* [fix] auto policy error message
* trivial
parent 283c407a19
commit bdf9a001d6
@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -64,7 +64,7 @@ class InferenceEngine:
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
-        model_policy: Policy = None,
+        model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
         self.inference_config = inference_config
         self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@ class InferenceEngine:
 
         self._verify_args()
 
-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
         """
         Shard model or/and Load weight
 
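With the widened Union[Policy, Type[Policy]] annotation (applied to both the constructor above and init_model here), callers may pass either a ready-made Policy instance or a bare Policy subclass; the engine now instantiates the latter itself. A minimal, non-runnable sketch of the two call styles (engine, model, and MyPolicy are placeholder names; a real Policy subclass must implement the shardformer policy interface, e.g. preprocess, module_policy, postprocess):

    # MyPolicy is a hypothetical Policy subclass, used only to show the call styles.
    engine.init_model(model, model_policy=MyPolicy())  # instance: accepted before and after
    engine.init_model(model, model_policy=MyPolicy)    # bare class: accepted after this commit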
@@ -150,11 +150,17 @@ class InferenceEngine:
         )
 
         if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
+            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+            model_policy = model_policy_map.get(model_policy_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
 
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
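Before this commit, an unsupported model type crashed the auto-policy lookup with a bare KeyError from model_policy_map[model_type]. The new code uses .get(), so the lookup returns None for unknown keys and the failure then surfaces as a descriptive ValueError when instantiation is attempted. A self-contained sketch of the resolution flow (the Policy stand-in, the resolve_policy helper, and the map contents below are illustrative, not the real ColossalAI API):

    from typing import Optional, Type, Union

    class Policy:
        """Stand-in for colossalai.shardformer's Policy base class."""

    class NoPaddingLlamaPolicy(Policy):
        """Illustrative auto-policy entry."""

    # Mirrors the engine's model_policy_map: keys are "<padding-mode>_<model_type>".
    model_policy_map = {"nopadding_llama": NoPaddingLlamaPolicy}

    def resolve_policy(
        model_policy: Union[Policy, Type[Policy], None],
        model_type: Optional[str],
        pad_input: bool,
    ) -> Policy:
        if model_policy is None:
            prefix = "nopadding" if not pad_input else "padding"
            # .get() yields None for unsupported model types instead of raising KeyError.
            model_policy = model_policy_map.get(f"{prefix}_{model_type}")
        if not isinstance(model_policy, Policy):
            try:
                # Instantiate a bare class; calling None fails with a clear message.
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")
        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        return model_policy

    print(resolve_policy(None, "llama", pad_input=False))  # NoPaddingLlamaPolicy instance
    try:
        resolve_policy(None, "mamba", pad_input=False)  # no "nopadding_mamba" entry
    except ValueError as e:
        print(e)  # Unable to instantiate model policy: 'NoneType' object is not callable

Because the same isinstance check also handles explicitly passed policies, a user-supplied Policy class and a failed auto-lookup both funnel through the one error path.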