From bdf9a001d61cfad4bb68752c4a808295165307a0 Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Mon, 20 May 2024 22:49:18 +0800
Subject: [PATCH] [Fix/Inference] Add unsupported auto-policy error message
 (#5730)

* [fix] auto policy error message

* trivial
---
 colossalai/inference/core/engine.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 646b3cede..96c2b15ee 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -64,7 +64,7 @@ class InferenceEngine:
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
-        model_policy: Policy = None,
+        model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
         self.inference_config = inference_config
         self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@ class InferenceEngine:
 
         self._verify_args()
 
-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
         """
         Shard model or/and Load weight
 
@@ -150,11 +150,17 @@ class InferenceEngine:
             )
 
         if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
+            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+            model_policy = model_policy_map.get(model_policy_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
 
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
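
Note for reviewers: below is a minimal, self-contained sketch of the auto-policy
resolution flow this patch introduces. Policy, model_policy_map,
NoPaddingLlamaPolicy, and resolve_model_policy are simplified stand-ins for
illustration (the real Policy base class and model_policy_map ship with
ColossalAI); only the resolution and error-handling logic mirrors the diff.

# Hypothetical stand-ins for illustration only -- the real Policy base class
# and model_policy_map come from ColossalAI; only the control flow below
# mirrors the patched init_model().
from typing import Dict, Optional, Type, Union


class Policy:
    """Stand-in for ColossalAI's shardformer Policy base class."""


class NoPaddingLlamaPolicy(Policy):
    """Hypothetical policy registered under the 'nopadding_llama' key."""


# Registry keyed by "<padding prefix>_<model_type>", as in the diff.
model_policy_map: Dict[str, Type[Policy]] = {
    "nopadding_llama": NoPaddingLlamaPolicy,
}


def resolve_model_policy(
    model_type: Optional[str],
    pad_input: bool,
    model_policy: Union[Policy, Type[Policy], None] = None,
) -> Policy:
    if model_policy is None:
        prefix = "nopadding" if not pad_input else "padding"
        model_policy_key = f"{prefix}_{model_type}"
        # dict.get() yields None for an unsupported model type instead of
        # raising a bare KeyError, as the old map[model_type]() lookup did.
        model_policy = model_policy_map.get(model_policy_key)

    if not isinstance(model_policy, Policy):
        # Reached with either a policy class (instantiate it) or None from a
        # failed lookup; calling None raises TypeError, which surfaces as a
        # readable ValueError naming the failure.
        try:
            model_policy = model_policy()
        except Exception as e:
            raise ValueError(f"Unable to instantiate model policy: {e}")

    assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
    return model_policy


print(resolve_model_policy("llama", pad_input=False))  # a NoPaddingLlamaPolicy instance

try:
    resolve_model_policy("some_unsupported_arch", pad_input=False)
except ValueError as err:
    # "Unable to instantiate model policy: 'NoneType' object is not callable"
    print(err)

Note that the isinstance branch accepts either an already-built Policy instance
or a policy class, which is why the signature widens to
Union[Policy, Type[Policy]] in the hunks above.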