diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 8a980bf9d..28ac2dc7f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -394,8 +394,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - # if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv: - # return [] + if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv: + return [] llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if ( @@ -403,20 +403,26 @@ class LlamaForCausalLMPolicy(LlamaPolicy): and self.pipeline_stage_manager.num_stages > 1 ): # tie weights - if self.pipeline_stage_manager.use_zbv: - return [ - { - 0: llama_model.embed_tokens.weight, - 0: self.model.lm_head.weight, - } - ] - else: - return [ - { - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - } - ] + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + # if self.pipeline_stage_manager.use_zbv: + # return [ + # { + # 0: llama_model.embed_tokens.weight, + # 0: self.model.lm_head.weight, + # } + # ] + # else: + # return [ + # { + # 0: llama_model.embed_tokens.weight, + # self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + # } + # ] return [] diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 0d80bc225..b60bdd03e 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -237,7 +237,6 @@ def main(): ).get_v_schedule() else: scheduler_nodes = None - # print(f"{dist.get_rank()} {scheduler_nodes[]} ") plugin = HybridParallelPlugin( tp_size=args.tp, diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 5f286d173..c485d3f54 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -923,9 +923,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @parameterize( "config", [ - # (1, 2, 2, 1), # Pass + (1, 2, 2, 1), # Pass # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture; - (0, 4, 1, 1), + # (0, 4, 1, 1), # (1, 2, 1, 2), # (1, 1, 2, 2), ],