mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] Fix model policy matching strategy in ShardFormer (#5064)
* hotfix/Fix get model policy strategy in ShardFormer * fix bug in auto policypull/5076/head
parent
4ccb9ded7d
commit
75af66cd81
|
@ -32,7 +32,7 @@ Colossal Inference is composed of three main components:
|
|||
|
||||
In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes.
|
||||
|
||||

|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png" alt="Colossal-Inference" style="zoom: 33%;"/>
|
||||
|
||||
## Roadmap of our implementation
|
||||
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
from .base_policy import Policy
|
||||
|
||||
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
|
||||
|
@ -150,39 +148,12 @@ _POLICY_LIST = {
|
|||
),
|
||||
}
|
||||
|
||||
_INFER_POLICY_LIST = {
|
||||
# LlaMa
|
||||
"transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaModelInferPolicy"
|
||||
),
|
||||
"transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaModelInferPolicy"
|
||||
),
|
||||
# Bloom
|
||||
"transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomModelInferPolicy"
|
||||
),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomModelInferPolicy"
|
||||
),
|
||||
# ChatGLM2
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLM2InferPolicy"
|
||||
),
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
|
||||
def import_policy(policy_location: PolicyLocation) -> Policy:
|
||||
"""
|
||||
Dynamically import a Policy class based on the policy location.
|
||||
"""
|
||||
if inference_only:
|
||||
module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
|
||||
else:
|
||||
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
|
||||
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, policy_location.class_name)
|
||||
|
||||
|
@ -198,7 +169,7 @@ def _fullname(obj):
|
|||
return module + "." + klass.__qualname__
|
||||
|
||||
|
||||
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
|
||||
def get_autopolicy(model: nn.Module) -> Policy:
|
||||
r"""
|
||||
Return the auto policy for the model
|
||||
|
||||
|
@ -209,16 +180,12 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
|||
:class:`Policy`: The auto policy for the model
|
||||
"""
|
||||
full_name = _fullname(model)
|
||||
inference_only = shard_config.extra_kwargs.get("inference_only", None)
|
||||
if inference_only:
|
||||
policy_location = _INFER_POLICY_LIST.get(full_name, None)
|
||||
else:
|
||||
policy_location = _POLICY_LIST.get(full_name, None)
|
||||
policy_location = _POLICY_LIST.get(full_name, None)
|
||||
|
||||
if policy_location is None:
|
||||
raise NotImplementedError(
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
|
||||
)
|
||||
else:
|
||||
policy = import_policy(policy_location, inference_only)
|
||||
policy = import_policy(policy_location)
|
||||
return policy()
|
||||
|
|
|
@ -28,7 +28,7 @@ class ModelSharder(object):
|
|||
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
|
||||
self.model = model
|
||||
self.shard_config = shard_config
|
||||
self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
|
||||
self.policy = get_autopolicy(self.model) if policy is None else policy
|
||||
|
||||
def shard(self) -> List[Dict[int, Tensor]]:
|
||||
r"""
|
||||
|
|
|
@ -19,7 +19,6 @@ def build_model(
|
|||
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||
enable_flash_attention=enable_flash_attention,
|
||||
enable_jit_fused=enable_jit_fused,
|
||||
extra_kwargs={"inference_only": True},
|
||||
)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
|
|
Loading…
Reference in New Issue