[Fix/Inference] Add unsupported auto-policy error message (#5730)

* [fix] auto policy error message

* trivial
pull/5737/head
Yuanheng Zhao 2024-05-20 22:49:18 +08:00 committed by GitHub
parent 283c407a19
commit bdf9a001d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 14 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
import time import time
from itertools import count from itertools import count
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np import numpy as np
import torch import torch
@@ -64,7 +64,7 @@ class InferenceEngine:
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig, inference_config: InferenceConfig,
verbose: bool = False, verbose: bool = False,
model_policy: Policy = None, model_policy: Union[Policy, Type[Policy]] = None,
) -> None: ) -> None:
self.inference_config = inference_config self.inference_config = inference_config
self.dtype = inference_config.dtype self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@ class InferenceEngine:
self._verify_args() self._verify_args()
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
""" """
Shard model or/and Load weight Shard model or/and Load weight
@@ -150,11 +150,17 @@
) )
if model_policy is None: if model_policy is None:
if self.inference_config.pad_input: prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_type = "padding_" + self.model_config.model_type model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
else: model_policy = model_policy_map.get(model_policy_key)
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]() if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS) tp_group = pg_mesh.get_group_along_axis(TP_AXIS)