mirror of https://github.com/hpcaitech/ColossalAI
[Fix/Inference] Add unsupported auto-policy error message (#5730)
* [fix] auto policy error message * trivial pull/5737/head
parent
283c407a19
commit
bdf9a001d6
|
@ -1,6 +1,6 @@
|
||||||
import time
|
import time
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -64,7 +64,7 @@ class InferenceEngine:
|
||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
inference_config: InferenceConfig,
|
inference_config: InferenceConfig,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
model_policy: Policy = None,
|
model_policy: Union[Policy, Type[Policy]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.inference_config = inference_config
|
self.inference_config = inference_config
|
||||||
self.dtype = inference_config.dtype
|
self.dtype = inference_config.dtype
|
||||||
|
@ -105,7 +105,7 @@ class InferenceEngine:
|
||||||
|
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
|
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
|
||||||
"""
|
"""
|
||||||
Shard model or/and Load weight
|
Shard model or/and Load weight
|
||||||
|
|
||||||
|
@ -150,11 +150,17 @@ class InferenceEngine:
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_policy is None:
|
if model_policy is None:
|
||||||
if self.inference_config.pad_input:
|
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||||
model_type = "padding_" + self.model_config.model_type
|
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||||
else:
|
model_policy = model_policy_map.get(model_policy_key)
|
||||||
model_type = "nopadding_" + self.model_config.model_type
|
|
||||||
model_policy = model_policy_map[model_type]()
|
if not isinstance(model_policy, Policy):
|
||||||
|
try:
|
||||||
|
model_policy = model_policy()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||||
|
|
||||||
|
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue