diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 28ac2dc7f..bef39a6ca 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -96,7 +96,8 @@ class LlamaPolicy(Policy): target_key=attn_cls, ) - if self.pipeline_stage_manager is not None: + # if self.pipeline_stage_manager is not None: + if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ "forward": partial( diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 33707a4f6..b43e45bcf 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -325,6 +325,7 @@ def run_llama_test(test_config): ).get_v_schedule() test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + print(f"name {name}") if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue try: