Browse Source

[Fix/Inference] Add unsupported auto-policy error message (#5730)

* [fix] auto policy error message

* trivial
pull/5737/head
Yuanheng Zhao 6 months ago committed by GitHub
parent
commit
bdf9a001d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 22
      colossalai/inference/core/engine.py

22
colossalai/inference/core/engine.py

@@ -1,6 +1,6 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
@@ -64,7 +64,7 @@ class InferenceEngine:
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@
self._verify_args()
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
"""
Shard model or/and Load weight
@@ -150,11 +150,17 @@
)
if model_policy is None:
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

Loading…
Cancel
Save