diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index c417e5d01..875c87476 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -40,14 +40,20 @@ class LlamaPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")

+        if self.shard_config.enable_tensor_parallelism:
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size":
+                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads":
+                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement={
-                    "self_attn.hidden_size":
-                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                    "self_attn.num_heads":
-                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                },
+                attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
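
Note (not part of the patch): below is a minimal standalone sketch of the sharding arithmetic that the new `decoder_attribute_replacement` encodes, assuming a config object exposing `hidden_size`, `num_attention_heads`, and, for GQA models, `num_key_value_heads`. The helper name `build_decoder_attribute_replacement` and the use of `SimpleNamespace` in place of a Hugging Face `LlamaConfig` are illustrative assumptions, not code from the repository.

```python
# Illustrative sketch only -- mirrors the attribute replacement built in the patch above.
from types import SimpleNamespace


def build_decoder_attribute_replacement(config, tensor_parallel_size):
    # Always shard the hidden size and the query heads across tensor-parallel ranks.
    replacement = {
        "self_attn.hidden_size": config.hidden_size // tensor_parallel_size,
        "self_attn.num_heads": config.num_attention_heads // tensor_parallel_size,
    }
    # Shard key/value heads only when the config defines them (grouped-query attention).
    if getattr(config, "num_key_value_heads", False):
        replacement["self_attn.num_key_value_heads"] = (
            config.num_key_value_heads // tensor_parallel_size
        )
    return replacement


# Llama-2-70B-like shape: 64 query heads, 8 key/value heads (GQA), sharded over 8 ranks.
config = SimpleNamespace(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)
print(build_decoder_attribute_replacement(config, tensor_parallel_size=8))
# -> {'self_attn.hidden_size': 1024, 'self_attn.num_heads': 8, 'self_attn.num_key_value_heads': 1}
```

The `getattr(..., False)` guard keeps the change backward compatible with older Llama configs that predate GQA and therefore define no `num_key_value_heads` attribute.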