diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 673701017..bba943f12 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1166,22 +1166,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                     num_microbatch=num_microbatches,
                     microbatch_size=microbatch_size,
                 )
-            elif pp_style == "zbv":
-                self.scheduler = ZeroBubbleVPipeScheduler(
-                    stage_manager=self.stage_manager,
-                    schedule=scheduler_nodes,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                )
-            elif pp_style == "zbv":
-                self.scheduler = ZeroBubbleVPipeScheduler(
-                    stage_manager=self.stage_manager,
-                    schedule=scheduler_nodes,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5d48a16c3..5c68d0c5e 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -289,9 +289,9 @@ class LlamaPolicy(Policy):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-                held_layers.append(module.norm)
-            elif stage_manager.is_last_stage(ignore_chunk=True):
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
                 held_layers.append(module.norm)
 
         else:
@@ -383,13 +383,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.lm_head)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
@@ -443,9 +445,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.score)
         return held_layers
 
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 6286cc6f0..a56a68cd3 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial
 from typing import Tuple
@@ -72,6 +73,9 @@ class MlpModel(nn.Module):
         else:
             return {"hidden_states": held_layers(hidden_states)}
 
+    def no_sync(self):
+        return nullcontext()
+
 
 def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
     for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index f3b4db1ce..04ef78221 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -114,14 +114,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     # check last hidden state & loss
     check_flag = False
-    if stage_manager is None:
+    if (
+        (stage_manager is None)
+        or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
+        or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
+    ):
         check_flag = True
-    else:
-        if stage_manager.use_zbv:
-            if stage_manager.is_first_stage(ignore_chunk=True):
-                check_flag = True
-        elif stage_manager.is_last_stage(ignore_chunk=True):
-            check_flag = True
     if check_flag:
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
@@ -292,6 +290,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_gradient_checkpointing": True,
             "parallel_output": False,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
     ],
 )
 def run_llama_test(test_config):