mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] Fix model policy matching strategy in ShardFormer (#5064)
* hotfix/Fix get model policy strategy in ShardFormer
* fix bug in auto policy
parent 4ccb9ded7d
commit 75af66cd81
branch pull/5076/head
@@ -32,7 +32,7 @@ Colossal Inference is composed of three main components:
 In this section we discuss how the colossal inference works and integrates with the `Shardformer`. The details can be found in our codes.
 
 <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png" alt="Colossal-Inference" style="zoom: 33%;"/>
 
 ## Roadmap of our implementation
@@ -1,10 +1,8 @@
 import importlib
 from dataclasses import dataclass
-from typing import Optional
 
 import torch.nn as nn
 
-from ..shard.shard_config import ShardConfig
 from .base_policy import Policy
 
 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -150,39 +148,12 @@ _POLICY_LIST = {
     ),
 }
 
-_INFER_POLICY_LIST = {
-    # LlaMa
-    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    # Bloom
-    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    # ChatGLM2
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
-    ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
-    ),
-}
-
 
-def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
+def import_policy(policy_location: PolicyLocation) -> Policy:
     """
     Dynamically import a Policy class based on the policy location.
     """
-    if inference_only:
-        module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
-    else:
-        module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
     module = importlib.import_module(module_name)
     return getattr(module, policy_location.class_name)
 
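For reference, the surviving code path is a plain lookup-then-import. Below is a minimal self-contained sketch of that pattern; `PolicyLocation` mirrors the dataclass defined in auto_policy.py, and the field comments are interpretive rather than quoted from the source.

# Self-contained sketch of the lookup-then-import pattern kept above.
# PolicyLocation mirrors the dataclass in auto_policy.py.
import importlib
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str   # module name under colossalai.shardformer.policies
    class_name: str  # policy class defined inside that module


def import_policy(policy_location: PolicyLocation):
    # Single namespace now: the removed `inference_only` branch used to
    # redirect lookups to colossalai.inference.tensor_parallel.policies.
    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
    module = importlib.import_module(module_name)
    return getattr(module, policy_location.class_name)

Keeping one registry and one module namespace means the chosen policy can depend only on the model class, not on how the `ShardConfig` was built.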
@@ -198,7 +169,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__
 
 
-def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
+def get_autopolicy(model: nn.Module) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -209,16 +180,12 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    inference_only = shard_config.extra_kwargs.get("inference_only", None)
-    if inference_only:
-        policy_location = _INFER_POLICY_LIST.get(full_name, None)
-    else:
-        policy_location = _POLICY_LIST.get(full_name, None)
+    policy_location = _POLICY_LIST.get(full_name, None)
 
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location)
     return policy()
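The registry key used above is the model's fully qualified class name. A small sketch of how that key is derived, simplified from `_fullname` in this file; `TinyModel` is a hypothetical stand-in:

# Sketch of how the registry key is derived: the exact fully qualified
# class name of the model instance (simplified from _fullname above).
import torch.nn as nn


def _fullname(obj) -> str:
    klass = obj.__class__
    # e.g. "transformers.models.llama.modeling_llama.LlamaModel"
    return klass.__module__ + "." + klass.__qualname__


class TinyModel(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


print(_fullname(TinyModel()))  # "__main__.TinyModel" when run as a script

Because matching is by exact import path, a wrapped or subclassed model yields a different key and falls through to the `NotImplementedError` branch.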
@@ -28,7 +28,7 @@ class ModelSharder(object):
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.shard_config = shard_config
-        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
+        self.policy = get_autopolicy(self.model) if policy is None else policy
 
     def shard(self) -> List[Dict[int, Tensor]]:
         r"""
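With the new signature, the automatic fallback depends only on the model instance. A hedged usage sketch follows; it assumes `transformers` is installed and that `LlamaForCausalLM` has an entry in `_POLICY_LIST`, and the printed class name is indicative rather than quoted from the registry.

# Hedged usage sketch: resolving a policy from the model instance alone,
# as ModelSharder now does when policy is None.
from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy

config = LlamaConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=1, num_attention_heads=4
)
model = LlamaForCausalLM(config)  # tiny config keeps construction cheap
policy = get_autopolicy(model)    # keyed on the class's fully qualified name
print(type(policy).__name__)      # e.g. "LlamaForCausalLMPolicy"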
@@ -19,7 +19,6 @@ def build_model(
         enable_tensor_parallelism=enable_tensor_parallelism,
         enable_flash_attention=enable_flash_attention,
         enable_jit_fused=enable_jit_fused,
-        extra_kwargs={"inference_only": True},
     )
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
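After this change, the test helper builds its `ShardConfig` without any inference flag. A sketch of the simplified construction; the flag values are illustrative, and with tensor parallelism enabled this must run inside a distributed context initialized via `colossalai.launch`.

# Sketch of the simplified test setup: no extra_kwargs / inference flag,
# since policy resolution no longer consults them.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    enable_tensor_parallelism=True,   # requires an initialized process group
    enable_flash_attention=False,
    enable_jit_fused=False,
)
shard_former = ShardFormer(shard_config=shard_config)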