mirror of https://github.com/hpcaitech/ColossalAI
100 lines
3.9 KiB
Python
100 lines
3.9 KiB
Python
|
from typing import Dict, Union
|
||
|
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||
|
|
||
|
|
||
|
class Grok1Policy(Policy):
|
||
|
def config_sanity_check(self):
|
||
|
pass
|
||
|
|
||
|
def preprocess(self) -> nn.Module:
|
||
|
if self.shard_config.enable_tensor_parallelism:
|
||
|
vocab_size = self.model.config.vocab_size
|
||
|
world_size = self.shard_config.tensor_parallel_size
|
||
|
assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
|
||
|
return self.model
|
||
|
|
||
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||
|
policy = {}
|
||
|
if self.shard_config.enable_tensor_parallelism:
|
||
|
decoder_attribute_replacement = {
|
||
|
"attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||
|
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||
|
"attn.num_key_value_heads": self.model.config.num_key_value_heads
|
||
|
// self.shard_config.tensor_parallel_size,
|
||
|
}
|
||
|
decoder_submodule_replacement = [
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix="attn.q_proj",
|
||
|
target_module=Linear1D_Col,
|
||
|
),
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix="attn.k_proj",
|
||
|
target_module=Linear1D_Col,
|
||
|
),
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix="attn.v_proj",
|
||
|
target_module=Linear1D_Col,
|
||
|
),
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix="attn.o_proj",
|
||
|
target_module=Linear1D_Row,
|
||
|
),
|
||
|
]
|
||
|
for i in range(self.model.config.num_experts):
|
||
|
decoder_submodule_replacement.extend(
|
||
|
[
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix=f"moe_block.experts[{i}].linear",
|
||
|
target_module=Linear1D_Col,
|
||
|
),
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix=f"moe_block.experts[{i}].linear_v",
|
||
|
target_module=Linear1D_Col,
|
||
|
),
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix=f"moe_block.experts[{i}].linear_1",
|
||
|
target_module=Linear1D_Row,
|
||
|
),
|
||
|
]
|
||
|
)
|
||
|
|
||
|
policy["DecoderLayer"] = ModulePolicyDescription(
|
||
|
attribute_replacement=decoder_attribute_replacement,
|
||
|
sub_module_replacement=decoder_submodule_replacement,
|
||
|
)
|
||
|
self.append_or_create_submodule_replacement(
|
||
|
description=SubModuleReplacementDescription(
|
||
|
suffix="embed_tokens",
|
||
|
target_module=VocabParallelEmbedding1D,
|
||
|
),
|
||
|
policy=policy,
|
||
|
target_key="Grok1Model",
|
||
|
)
|
||
|
return policy
|
||
|
|
||
|
def postprocess(self):
|
||
|
return self.model
|
||
|
|
||
|
|
||
|
class Grok1ModelPolicy(Grok1Policy):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class Grok1ForCausalLMPolicy(Grok1Policy):
|
||
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||
|
policy = super().module_policy()
|
||
|
self.append_or_create_submodule_replacement(
|
||
|
description=SubModuleReplacementDescription(
|
||
|
suffix="lm_head",
|
||
|
target_module=Linear1D_Col,
|
||
|
kwargs={"gather_output": not self.shard_config.parallel_output},
|
||
|
),
|
||
|
policy=policy,
|
||
|
target_key="Grok1ModelForCausalLM",
|
||
|
)
|
||
|
return policy
|