# Grok-1 tensor-parallel policy for ColossalAI's shardformer (mirror of https://github.com/hpcaitech/ColossalAI).
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

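# Tensor-parallel sharding plan for Grok-1: the attention q/k/v projections and
# each MoE expert's input projections are split column-wise across
# tensor-parallel ranks (Linear1D_Col), the attention output projection and each
# expert's output projection are split row-wise (Linear1D_Row), and the token
# embedding is sharded over the vocabulary dimension (VocabParallelEmbedding1D).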
class Grok1Policy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            # Each rank keeps only its shard of the attention heads, so the per-layer
            # attributes are divided by the tensor-parallel degree.
            decoder_attribute_replacement = {
                "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
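            # Megatron-style attention sharding: the q/k/v projections are column-parallel
            # (each rank holds a subset of heads) and the output projection is row-parallel,
            # so its partial results are reduced across ranks.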
            decoder_submodule_replacement = [
                SubModuleReplacementDescription(
                    suffix="attn.q_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.k_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.v_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.o_proj",
                    target_module=Linear1D_Row,
                ),
            ]
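            # Shard each MoE expert the same way: its two input projections
            # (linear and linear_v) are column-parallel and its output
            # projection (linear_1) is row-parallel.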
            for i in range(self.model.config.num_experts):
                decoder_submodule_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_v",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_1",
                            target_module=Linear1D_Row,
                        ),
                    ]
                )

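            # Register the attribute and sub-module replacements for every decoder layer,
            # and shard the token embedding over the vocabulary dimension.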
            policy["DecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=decoder_submodule_replacement,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key="Grok1Model",
            )
        return policy

    def postprocess(self):
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    pass


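# The causal-LM variant additionally makes lm_head column-parallel; its outputs
# are gathered unless shard_config.parallel_output requests sharded logits.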
class Grok1ForCausalLMPolicy(Grok1Policy):
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = super().module_policy()
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=Linear1D_Col,
                kwargs={"gather_output": not self.shard_config.parallel_output},
            ),
            policy=policy,
            target_key="Grok1ModelForCausalLM",
        )
        return policy
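

# ---------------------------------------------------------------------------
# Usage sketch: one plausible way to apply the policies above with ShardFormer.
# The model-loading step, the process-group choice, and the exact ShardConfig
# arguments below are assumptions and may need adjusting to the ColossalAI
# version in use.
#
#     import torch.distributed as dist
#     from colossalai.shardformer import ShardConfig, ShardFormer
#
#     # distributed environment assumed to be initialized already
#     # (e.g. via colossalai.launch_from_torch under torchrun)
#     shard_config = ShardConfig(
#         tensor_parallel_process_group=dist.group.WORLD,  # all ranks form one TP group
#         enable_tensor_parallelism=True,
#     )
#     model = ...                                          # build/load the Grok-1 causal-LM module
#     shardformer = ShardFormer(shard_config=shard_config)
#     model, shared_params = shardformer.optimize(model, policy=Grok1ForCausalLMPolicy())
# ---------------------------------------------------------------------------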