diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 2a79d4a88..dfac7cfd9 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -32,7 +32,7 @@ Colossal Inference is composed of three main components:
 
 In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes.
 
-![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png)
+Colossal-Inference
 
 ## Roadmap of our implementation
 
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index c46934fb0..b01896e48 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -1,10 +1,8 @@
 import importlib
 from dataclasses import dataclass
-from typing import Optional
 
 import torch.nn as nn
 
-from ..shard.shard_config import ShardConfig
 from .base_policy import Policy
 
 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -150,39 +148,12 @@ _POLICY_LIST = {
     ),
 }
 
-_INFER_POLICY_LIST = {
-    # LlaMa
-    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    # Bloom
-    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    # ChatGLM2
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
-    ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
-    ),
-}
-
 
-def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
+def import_policy(policy_location: PolicyLocation) -> Policy:
     """
     Dynamically import a Policy class based on the policy location.
     """
-    if inference_only:
-        module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
-    else:
-        module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
     module = importlib.import_module(module_name)
     return getattr(module, policy_location.class_name)
 
@@ -198,7 +169,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__
 
 
-def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
+def get_autopolicy(model: nn.Module) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -209,16 +180,12 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    inference_only = shard_config.extra_kwargs.get("inference_only", None)
-    if inference_only:
-        policy_location = _INFER_POLICY_LIST.get(full_name, None)
-    else:
-        policy_location = _POLICY_LIST.get(full_name, None)
+    policy_location = _POLICY_LIST.get(full_name, None)
 
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location)
         return policy()
 
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 0586ada9e..fc2f92778 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -28,7 +28,7 @@ class ModelSharder(object):
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.shard_config = shard_config
-        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
+        self.policy = get_autopolicy(self.model) if policy is None else policy
 
     def shard(self) -> List[Dict[int, Tensor]]:
         r"""
diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py
index 0bd791cc8..310c214f4 100644
--- a/tests/test_infer/_utils.py
+++ b/tests/test_infer/_utils.py
@@ -19,7 +19,6 @@ def build_model(
         enable_tensor_parallelism=enable_tensor_parallelism,
         enable_flash_attention=enable_flash_attention,
         enable_jit_fused=enable_jit_fused,
-        extra_kwargs={"inference_only": True},
     )
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
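
For reference, a minimal usage sketch of a call site after this patch. The `ShardConfig` flags mirror the ones visible in the `tests/test_infer/_utils.py` hunk above; `model`, the flag values, and the surrounding setup are illustrative assumptions, not part of the diff:

```python
# Minimal sketch, assuming torch.distributed is already initialized
# (e.g. via colossalai.launch) and `model` is a supported transformers model.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    enable_tensor_parallelism=True,   # illustrative values
    enable_flash_attention=False,
    enable_jit_fused=False,
    # extra_kwargs={"inference_only": True} is no longer passed here:
    # the policy is resolved from the single _POLICY_LIST registry.
)
shard_former = ShardFormer(shard_config=shard_config)

# With policy=None, ModelSharder falls back to get_autopolicy(model),
# which looks up the model's fully qualified class name in _POLICY_LIST.
sharded_model, shared_params = shard_former.optimize(model, policy=None)
```

Folding the inference registry into the standard one removes the `inference_only` plumbing through `ShardConfig.extra_kwargs`, so `get_autopolicy` and `import_policy` no longer need to branch on the caller's intent.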