[Hotfix] Fix model policy matching strategy in ShardFormer (#5064)

* hotfix: fix the model policy matching strategy in ShardFormer

* fix bug in auto policy
Zhongkai Zhao committed 1 year ago via GitHub (commit 75af66cd81, parent 4ccb9ded7d)
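For context, this commit collapses policy lookup into the single `_POLICY_LIST` table and drops the separate `_INFER_POLICY_LIST` and the `inference_only` flag. A minimal sketch of the resulting matching strategy, reconstructed from the diff below (the Llama entry and its policy class name are illustrative, not a complete table):

```python
import importlib
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


# One table now maps fully qualified model class names to policy locations;
# the separate _INFER_POLICY_LIST and the inference_only flag are removed.
_POLICY_LIST = {
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
        file_name="llama", class_name="LlamaForCausalLMPolicy"  # illustrative entry
    ),
}


def _fullname(obj) -> str:
    # e.g. "transformers.models.llama.modeling_llama.LlamaForCausalLM"
    klass = obj.__class__
    return klass.__module__ + "." + klass.__qualname__


def get_autopolicy(model):
    full_name = _fullname(model)
    policy_location = _POLICY_LIST.get(full_name, None)
    if policy_location is None:
        raise NotImplementedError(f"Auto policy for {full_name} is not implemented")
    # All policies now live under colossalai.shardformer.policies.
    module = importlib.import_module(f"colossalai.shardformer.policies.{policy_location.file_name}")
    return getattr(module, policy_location.class_name)()
```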

@@ -32,7 +32,7 @@ Colossal Inference is composed of three main components:
 In this section we discuss how Colossal Inference works and how it integrates with the `Shardformer`. The details can be found in our code.

-![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png)
+<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png" alt="Colossal-Inference" style="zoom: 33%;"/>

 ## Roadmap of our implementation
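As a rough usage sketch of that integration (not taken from the README; the `ShardFormer.optimize` call and the config flags mirror the test code later in this diff and may differ between versions):

```python
import torch.nn as nn
from colossalai.shardformer import ShardConfig, ShardFormer


def shard_for_inference(model: nn.Module) -> nn.Module:
    # Flags mirror the ones in the build_model test below; in practice this
    # is assumed to run inside an initialized distributed environment.
    shard_config = ShardConfig(
        enable_tensor_parallelism=True,
        enable_flash_attention=False,
        enable_jit_fused=False,
    )
    shard_former = ShardFormer(shard_config=shard_config)
    # With policy=None, ModelSharder falls back to get_autopolicy(model).
    sharded_model, _shared_params = shard_former.optimize(model, policy=None)
    return sharded_model
```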

@@ -1,10 +1,8 @@
 import importlib
 from dataclasses import dataclass
-from typing import Optional

 import torch.nn as nn

-from ..shard.shard_config import ShardConfig
 from .base_policy import Policy

 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -150,39 +148,12 @@ _POLICY_LIST = {
     ),
 }

-_INFER_POLICY_LIST = {
-    # LlaMa
-    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    # Bloom
-    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    # ChatGLM2
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
-    ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
-    ),
-}

-def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
+def import_policy(policy_location: PolicyLocation) -> Policy:
     """
     Dynamically import a Policy class based on the policy location.
     """
-    if inference_only:
-        module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
-    else:
-        module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
     module = importlib.import_module(module_name)
     return getattr(module, policy_location.class_name)
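The import that remains is plain `importlib`; a self-contained illustration of the same import-by-dotted-name pattern, using only the standard library so it runs anywhere:

```python
import importlib


def load_class(module_name: str, class_name: str):
    # Same pattern as import_policy: import the module lazily by its
    # dotted name, then fetch the class attribute from it.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Stand-in for a PolicyLocation lookup.
OrderedDict = load_class("collections", "OrderedDict")
print(OrderedDict(a=1))  # OrderedDict([('a', 1)])
```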
@@ -198,7 +169,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__


-def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
+def get_autopolicy(model: nn.Module) -> Policy:
     r"""
     Return the auto policy for the model
@@ -209,16 +180,12 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    inference_only = shard_config.extra_kwargs.get("inference_only", None)
-    if inference_only:
-        policy_location = _INFER_POLICY_LIST.get(full_name, None)
-    else:
-        policy_location = _POLICY_LIST.get(full_name, None)
+    policy_location = _POLICY_LIST.get(full_name, None)

     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location)
     return policy()
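A hedged usage sketch of the simplified entry point (assuming the function lives in `colossalai.shardformer.policies.auto_policy` and that `transformers` is installed; the tiny config is illustrative):

```python
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# A registered model class resolves through _POLICY_LIST alone.
tiny_cfg = LlamaConfig(hidden_size=32, intermediate_size=64,
                       num_attention_heads=4, num_hidden_layers=1)
policy = get_autopolicy(LlamaForCausalLM(tiny_cfg))

# An unregistered module raises NotImplementedError, and the message now
# lists only _POLICY_LIST keys since the inference list is gone.
try:
    get_autopolicy(nn.Linear(2, 2))
except NotImplementedError:
    print("no auto policy for nn.Linear")
```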

@@ -28,7 +28,7 @@ class ModelSharder(object):
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.shard_config = shard_config
-        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
+        self.policy = get_autopolicy(self.model) if policy is None else policy

     def shard(self) -> List[Dict[int, Tensor]]:
         r"""

@@ -19,7 +19,6 @@ def build_model(
         enable_tensor_parallelism=enable_tensor_parallelism,
         enable_flash_attention=enable_flash_attention,
         enable_jit_fused=enable_jit_fused,
-        extra_kwargs={"inference_only": True},
     )
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
