From c3d5fa3bac85baa07e30e2978a7517034ba7e0aa Mon Sep 17 00:00:00 2001
From: eric8607242
Date: Thu, 7 Sep 2023 10:15:13 +0800
Subject: [PATCH] [shardformer] Support customized policy for llamav2 based
 model with HybridParallelPlugin (#4624)

* Enable policy assignment in HybridPlugin and enable llama policy for llamav2

* Remove Policy from Plugin

* revert changes of plugin HybridParallelModule

* revert changes in plugin

* upgrade transformers

* revert transformers version

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
---
 colossalai/shardformer/policies/llama.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index c417e5d01..875c87476 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -40,14 +40,20 @@ class LlamaPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
+        if self.shard_config.enable_tensor_parallelism:
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size":
+                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads":
+                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+
         policy[LlamaDecoderLayer] = ModulePolicyDescription(
-            attribute_replacement={
-                "self_attn.hidden_size":
-                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads":
-                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            },
+            attribute_replacement=decoder_attribute_replacement,
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
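
Note on what the new branch computes: with tensor parallelism enabled, each decoder layer's attention attributes are divided by the tensor-parallel degree, and Llama 2's grouped-query attention adds a third attribute, num_key_value_heads, that must be sharded the same way (older Llama configs lack it, hence the getattr guard). Below is a minimal standalone sketch of that logic, not the patch itself; FakeLlamaConfig and tp_size are hypothetical stand-ins for the real transformers config and shard_config.tensor_parallel_size.

    # Standalone sketch of the per-rank attribute computation; FakeLlamaConfig
    # is a hypothetical stand-in, not a ColossalAI or transformers API.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeLlamaConfig:
        hidden_size: int = 8192
        num_attention_heads: int = 64
        # Llama 2 (GQA) checkpoints expose num_key_value_heads; Llama 1 does not.
        num_key_value_heads: Optional[int] = 8

    def decoder_attribute_replacement(config: FakeLlamaConfig, tp_size: int) -> dict:
        # Shard the hidden size and query-head count across tensor-parallel ranks.
        replacement = {
            "self_attn.hidden_size": config.hidden_size // tp_size,
            "self_attn.num_heads": config.num_attention_heads // tp_size,
        }
        # Only GQA configs carry num_key_value_heads, hence the getattr guard.
        if getattr(config, "num_key_value_heads", False):
            replacement["self_attn.num_key_value_heads"] = config.num_key_value_heads // tp_size
        return replacement

    # With tp_size=8, each rank keeps 8 query heads and 1 KV head.
    print(decoder_attribute_replacement(FakeLlamaConfig(), tp_size=8))

Without the num_key_value_heads branch, a GQA checkpoint would keep the full KV-head count on every rank while k_proj/v_proj are sharded, so the attention shapes would disagree; sharding it alongside num_heads keeps the per-rank head grouping consistent.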