from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)


class Grok1Policy(Policy):
    """Base ShardFormer policy for Grok-1: shards attention, MoE experts, and the token embedding."""

    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        if self.shard_config.enable_tensor_parallelism:
            # The embedding is split along the vocab dimension, so the vocab size
            # must divide evenly across tensor-parallel ranks.
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            assert (
                vocab_size % world_size == 0
            ), f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            # Per-rank attention attributes after splitting heads across tensor-parallel ranks.
            decoder_attribute_replacement = {
                "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
            # Q/K/V projections are column-parallel; the output projection is row-parallel.
            decoder_submodule_replacement = [
                SubModuleReplacementDescription(
                    suffix="attn.q_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.k_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.v_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.o_proj",
                    target_module=Linear1D_Row,
                ),
            ]
            # Shard every MoE expert: the first two expert linears are column-parallel,
            # the last one is row-parallel, following the usual gated-MLP sharding pattern.
            for i in range(self.model.config.num_experts):
                decoder_submodule_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_v",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_1",
                            target_module=Linear1D_Row,
                        ),
                    ]
                )
            policy["DecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=decoder_submodule_replacement,
            )
        # Replace the token embedding with a vocab-parallel embedding.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="embed_tokens",
                target_module=VocabParallelEmbedding1D,
            ),
            policy=policy,
            target_key="Grok1Model",
        )
        return policy

    def postprocess(self):
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    pass


class Grok1ForCausalLMPolicy(Grok1Policy):
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = super().module_policy()
        # The LM head is column-parallel; gather its output only when parallel output is disabled.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=Linear1D_Col,
                kwargs={"gather_output": not self.shard_config.parallel_output},
            ),
            policy=policy,
            target_key="Grok1ModelForCausalLM",
        )
        return policy
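

# Usage sketch (not part of the policy itself): these classes are intended to be handed to
# ColossalAI as a custom ShardFormer policy. A minimal sketch, assuming a Grok-1 model already
# loaded (e.g. via HuggingFace `AutoModelForCausalLM`) and ColossalAI's `HybridParallelPlugin`
# with its `custom_policy` argument; the launch/loading code around it is an assumption, not
# something defined in this module:
#
#     from colossalai.booster import Booster
#     from colossalai.booster.plugin import HybridParallelPlugin
#
#     plugin = HybridParallelPlugin(tp_size=8, pp_size=1, custom_policy=Grok1ForCausalLMPolicy())
#     booster = Booster(plugin=plugin)
#     model, *_ = booster.boost(model)  # returns the tensor-parallel sharded Grok-1 model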