diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 58b36f624..f678d7d7f 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -45,7 +45,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_model_chunks: int,
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
-        enable_metadata_cache: bool = True,
+        enable_metadata_cache: bool = False,
         overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
@@ -679,6 +679,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
+        # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
 
         # Step3:
         # 3-1:detach output; detach output for send fwd;
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 9640d8187..b18aa933c 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -194,15 +194,7 @@ class LlamaPolicy(Policy):
 
         # not enable tp, replace layer to LinearWithGradAccum
         elif use_zbv:
-            decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
-                "self_attn.num_heads": num_q_heads,
-            }
-            if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
-
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index 4d16038c1..b4b87df92 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    LinearWithGradAccum,
     PaddingEmbedding,
     PaddingLMHead,
     VocabParallelEmbedding1D,
@@ -62,6 +63,8 @@ class MistralPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding
 
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
@@ -90,6 +93,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -97,6 +101,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -104,6 +109,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -111,6 +117,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -118,6 +125,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -125,6 +133,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -132,6 +141,68 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                ],
+            )
+        elif use_zbv:
+            policy[MistralDecoderLayer] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                 ],
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 3d0966070..ddb70e5f2 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -36,6 +36,24 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
+def register_hooks(module: torch.nn.Module):
+
+    def fwd_hook(module, input, output):
+        torch.cuda.synchronize()
+        name = module._name if hasattr(module, "_name") else module
+        print(f"Fwd hook {name} \n output {output}")
+
+    def bwd_hook(module, grad_input, grad_output):
+        torch.cuda.synchronize()
+
+    def bwd_pre_hook(module, grad_output):
+        torch.cuda.synchronize()
+
+    module.register_forward_hook(fwd_hook)
+    # module.register_backward_hook(bwd_hook)
+    # module.register_full_backward_pre_hook(bwd_pre_hook)
+
+
 class MlpModel(nn.Module):
     def __init__(
         self,
@@ -756,9 +774,9 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         (1, 2, 1, 1, 2),
         (1, 1, 2, 2, 1),
         (1, 2, 1, 2, 1),
-        # TODO: adapt mixtral with no TP Linear
-        # (1, 2, 2, 1, 1),
-        # (0, 1, 4, 1, 1),
+        (1, 2, 2, 1, 1),
+        # # TODO: adapt mixtral with no TP Linear
+        (0, 1, 4, 1, 1),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@@ -872,7 +890,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
                 return_outputs=True,
             )
             # stage 0 chunk 0
-            parallel_output = None
             if (
                 booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
                 and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
@@ -880,6 +897,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
                 parallel_output = sharded_output["loss"]
             else:
                 parallel_output = torch.tensor(12345.0, device="cuda")
+            print(f"rank {dist.get_rank()} parallel_output {parallel_output}")
 
             # broadcast along pp axis
             dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
@@ -920,8 +938,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         (1, 2, 2, 1),
         (1, 2, 1, 2),
         (1, 1, 2, 2),
-        # TODO: acc err in pp4
-        # (1, 4, 1, 1),
+        # TODO: support overlap p2p in pp4
+        (1, 4, 1, 1),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
@@ -1030,7 +1048,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
                 return_outputs=True,
            )
            # stage 0 chunk 0
-            parallel_output = None
            if (
                booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
                and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
@@ -1054,6 +1071,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
        all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
+        # torch_model.apply(register_hooks) # register hook for base model
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
@@ -1065,19 +1083,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
        torch_optimizer.step()
        torch_optimizer.zero_grad()
 
-        # # assert param
-        # for parall_name, parall_param in parallel_model.named_parameters():
-        #     parall_name = ".".join(parall_name.split(".")[1:])
-        #     for base_name, base_param in torch_model.named_parameters():
-        #         if parall_name == base_name:
-        #             # print(f"parall_name {parall_name} parall_param.grad {parall_param.grad is not None}, base_name {base_name} base_param.grad {base_param.grad is not None}")
-        #             # # assert weight
-        #             assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
-        #             # # assert weight.grad
-        #             if parall_param.grad is not None:
-        #                 # print(f"parall_param.grad {parall_param.grad}, base_param.grad {base_param.grad}")
-        #                 assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
-
+        print(f"parallel_output {parallel_output}, torch_output_sum {torch_output_sum}")
        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
    print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
    clear_layout_converter()