diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 713175c6c..a9c982231 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -351,7 +351,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism: + if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription(