From 4fc92aa77dafd4a8253ed5ea4c16f090b44d2744 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 5 Nov 2024 05:55:42 +0000 Subject: [PATCH 01/16] [feat] support no_tp Linear for sharderformer.llama --- .../pipeline/schedule/zero_bubble_pp.py | 53 ++++----- colossalai/shardformer/modeling/llama.py | 1 - colossalai/shardformer/policies/llama.py | 101 +++++++++++++++++- examples/language/llama/benchmark.py | 7 -- .../test_schedule/test_zerobubble_pp.py | 20 ++-- 5 files changed, 140 insertions(+), 42 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index e310e9bf3..bab118b85 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -64,10 +64,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache - self.send_tensor_metadata = True - self.send_grad_metadata = True - self.tensor_metadata_recv = None - self.grad_metadata_recv = None + self.send_tensor_metadata = [True, True] + self.send_grad_metadata = [True, True] + # meta cache buffer + self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta] + self.grad_metadata_recv = [None, None] # P2P communication self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) @@ -235,10 +236,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() input_tensor, wait_handles = self.comm.recv_forward( - prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv + prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) # return input_tensor, wait_handles return wait_handles @@ -259,10 +260,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() input_tensor, wait_handles = self.comm.recv_forward( - next_rank, metadata_recv=self.tensor_metadata_recv + next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) # return input_tensor, wait_handles return wait_handles @@ -297,10 +298,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() output_tensor_grad, wait_handles = self.comm.recv_backward( - next_rank, metadata_recv=self.grad_metadata_recv + next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) # return 
output_tensor_grad, wait_handles return wait_handles @@ -322,10 +323,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() output_tensor_grad, wait_handles = self.comm.recv_backward( - next_rank=prev_rank, metadata_recv=self.grad_metadata_recv + next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) # return output_tensor_grad, wait_handles return wait_handles @@ -359,9 +360,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward( - output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata + output_object=output_tensor, + next_rank=next_rank, + send_metadata=self.send_tensor_metadata[model_chunk_id], ) - self.send_tensor_metadata = not self.enable_metadata_cache + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles else: @@ -380,9 +383,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward( - output_tensor, prev_rank, send_metadata=self.send_tensor_metadata + output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id] ) - self.send_tensor_metadata = not self.enable_metadata_cache + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: @@ -415,9 +418,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): prev_rank = self.stage_manager.get_prev_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward( - input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata + input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id] ) - self.send_grad_metadata = not self.enable_metadata_cache + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles # bwd chunk1 is left V; @@ -437,9 +440,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): next_rank = self.stage_manager.get_next_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward( - input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata + input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id] ) - self.send_grad_metadata = not self.enable_metadata_cache + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles def forward_step( @@ -662,6 +665,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) + # print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};") # Step3: # 3-1:detach output; detach output for send fwd; @@ -886,6 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): schedule = self.schedules[self.stage_manager.stage] # get schedule by 
stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] + # print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};") if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a51a1df9f..d1ad84604 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -191,7 +191,6 @@ class LlamaPipelineForwards: num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx - for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2b3a30bad..528638f41 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -9,6 +9,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, RMSNorm, @@ -104,7 +105,7 @@ class LlamaPolicy(Policy): policy=policy, target_key=LlamaModel, ) - + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: assert ( num_q_heads % tp_size == 0 @@ -191,6 +192,84 @@ class LlamaPolicy(Policy): ], ) + # not enable tp, replace layer to LinearWithGradAccum + else: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + 
fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -416,6 +495,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): policy = super().module_policy() use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { @@ -434,6 +514,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) + # enable tp, replace layer to LinearWithGradAccum + else: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + # to be confirmed if self.pipeline_stage_manager: # set None as default diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index ad5d35161..4976f0c37 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -163,8 +163,6 @@ def main(): enable_async_reduce=not args.disable_async_reduce, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, - use_fp8=args.use_fp8, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -179,8 +177,6 @@ def main(): enable_flash_attention=args.xformers, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, - use_fp8=args.use_fp8, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp": if use_empty_init: @@ -192,7 +188,6 @@ def main(): ), param_init_fn=empty_init(), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -214,7 +209,6 @@ def main(): cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -225,7 +219,6 @@ def main(): ), cpu_offload=CPUOffload(offload_params=True), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": if args.pp_style == "zbv": diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 71ff11059..b8ef09bea 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -758,11 +758,13 @@ def run_with_hybridplugin(test_config): @parameterize( "config", [ - # (0, 1, 4, 1, 1), + # # Pass + (1, 2, 1, 1, 2), + # TODO: adapt mixtral with no TP Linear # (1, 2, 2, 1, 1), - (1, 1, 2, 2, 1), + # (0, 1, 4, 1, 1), + # (1, 1, 2, 2, 1), # (1, 2, 1, 2, 1), - # (1, 2, 1, 1, 2), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @@ -910,7 +912,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): p.grad /= dp_size torch_optimizer.step() torch_optimizer.zero_grad() - assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} config {test_config} test passed") clear_layout_converter() @@ -921,11 +922,12 @@ def run_with_booster_moehybridplugin(config: 
Tuple[int, ...]): @parameterize( "config", [ - (1, 2, 2, 1), # Pass - # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture; - # (0, 4, 1, 1), - # (1, 2, 1, 2), - # (1, 1, 2, 2), + # # Pass + (1, 2, 2, 1), + (1, 2, 1, 2), + (1, 1, 2, 2), + # TODO: acc err in pp4 + (1, 4, 1, 1), ], ) def run_with_booster_hybridplugin(config: Tuple[int, ...]): From 0d6d40ccc62b5eaa514c7f4f8cc525ce159ff038 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 6 Nov 2024 03:35:12 +0000 Subject: [PATCH 02/16] [fix] fix zbv llama pp4 --- .../pipeline/schedule/zero_bubble_pp.py | 33 ------------------- .../test_schedule/test_zerobubble_pp.py | 25 ++++++++------ .../test_model/test_shard_llama.py | 2 +- 3 files changed, 16 insertions(+), 44 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index bab118b85..7bdb6d11e 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -226,7 +226,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; cause u are chunk 0 in first rank, u have no prev rank; ################# if self.stage_manager.is_first_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -241,7 +240,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) - # return input_tensor, wait_handles return wait_handles else: @@ -265,7 +263,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) - # return input_tensor, wait_handles return wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -313,7 +310,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; get loss from local ################ if self.stage_manager.is_first_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -328,7 +324,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) - # return output_tensor_grad, wait_handles return wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -665,7 +660,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) - # print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};") # Step3: # 3-1:detach output; detach output for send fwd; @@ -748,20 +742,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) - # # save output_tensor_grad for dw - # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # # we save loss here - # self.output_tensors_grad_dw[model_chunk_id].append(output_obj) - # else: - # # we save output_tensor_grad here - # 
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) - # the_output_obj_grad = [] - # if isinstance(output_obj, dict): - # for (k, v) in output_obj.items(): - # the_output_obj_grad.append(v.requires_grad) - # else: - # the_output_obj_grad.append(output_obj.requires_grad) - input_object_grad = self.backward_b_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, @@ -804,20 +784,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): Returns: Nothing. """ - - # get y & dy from buffer - # output_obj = self.output_tensors_dw[model_chunk_id].pop(0) - # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) WeightGradStore.pop(chunk=model_chunk_id) - # self.backward_w_step( - # model_chunk=model_chunk, - # model_chunk_id=model_chunk_id, - # optimizer=optimizer, - # output_obj=output_obj, - # output_obj_grad=output_obj_grad, - # ) - def run_forward_only( self, model_chunk: Union[ModuleList, Module], @@ -890,7 +858,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] - # print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};") if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index b8ef09bea..bda3a5512 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -749,12 +749,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config): assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) -# TODO:3) support booster & Hybrid base 2) -def run_with_hybridplugin(test_config): - pass - - -# TODO:4) support booster & MoEHybrid base 2) @parameterize( "config", [ @@ -923,9 +917,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): "config", [ # # Pass - (1, 2, 2, 1), - (1, 2, 1, 2), - (1, 1, 2, 2), + # (1, 2, 2, 1), + # (1, 2, 1, 2), + # (1, 1, 2, 2), # TODO: acc err in pp4 (1, 4, 1, 1), ], @@ -1071,6 +1065,17 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() + # assert param + for parall_name, parall_param in parallel_model.named_parameters(): + parall_name = ".".join(parall_name.split(".")[1:]) + for base_name, base_param in torch_model.named_parameters(): + if parall_name == base_name: + # assert weight + assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name) + # assert weight.grad + if parall_param.grad is not None: + assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad") + assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") clear_layout_converter() @@ -1081,7 +1086,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_with_booster_moehybridplugin() + # run_with_booster_moehybridplugin() run_with_booster_hybridplugin() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py 
b/tests/test_shardformer/test_model/test_shard_llama.py index 33707a4f6..c0690e5fd 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -420,4 +420,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - test_llama_3d() + # test_llama_3d() From 12919de424de5acf1bb9fe3f409ece8ad41ab9ef Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 11 Nov 2024 08:54:39 +0000 Subject: [PATCH 03/16] [fix] fix send_tensor_metadata & send_grad_metadata; --- colossalai/pipeline/p2p.py | 1 + .../pipeline/schedule/zero_bubble_pp.py | 32 ++++++++++++--- colossalai/shardformer/policies/llama.py | 39 ++++++++++--------- .../test_schedule/test_zerobubble_pp.py | 32 +++++++-------- .../test_model/test_shard_llama.py | 2 +- 5 files changed, 65 insertions(+), 41 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index b7b284213..8dbb6ec78 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -432,6 +432,7 @@ def _communicate( overlap_p2p=overlap_p2p, send_first=send_first if send_first != None else True, ) + # print(f"rank {dist.get_rank()}; recv_src {recv_src}; send_dst {send_dst}; metadata_send {metadata_send}; metadata_recv {metadata_recv};") if metadata_recv is not None: assert isinstance(metadata_recv, P2PMetadata) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 7bdb6d11e..b608fc3a0 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -64,8 +64,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache - self.send_tensor_metadata = [True, True] - self.send_grad_metadata = [True, True] + + # check send_tensor_metadata, send_grad_metadata + # pp4 as sample, we should follow this meta strategy + # send_tensor_meta(fwd) send_grad_meta(bwd) + # chunk0 | chunk1 chunk0 | chunk 1 + # stage 0 T | F F | T + # stage 1 T | T T | T + # stage 2 T | T T | T + # stage 3 F | T F | T + if stage_manager.is_first_stage(ignore_chunk=True): + self.send_tensor_metadata = [True, False] + self.send_grad_metadata = [False, True] + elif stage_manager.is_last_stage(ignore_chunk=True): + self.send_tensor_metadata = [False, True] + self.send_grad_metadata = [True, False] + else: + self.send_tensor_metadata = [True, True] + self.send_grad_metadata = [True, True] + # meta cache buffer self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta] self.grad_metadata_recv = [None, None] @@ -84,6 +101,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # init buffer self._free_buffers() + def _set_send_metadata_buffers(self, model_chunk_id): + pass + def _free_buffers(self): # free local buffer # two dim array, first dim is the model chunk, second dim is the microbatch queue @@ -285,7 +305,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; Already get dy from local_send_backward_buffer in schedule b ################ if self.stage_manager.is_last_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -300,7 +319,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) - # return output_tensor_grad, wait_handles return 
wait_handles else: @@ -345,6 +363,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; hold y on local_send_forward_buffer ################ if self.stage_manager.is_last_stage(ignore_chunk=True): + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -368,6 +387,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part ################ if self.stage_manager.is_first_stage(ignore_chunk=True): + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -403,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; cause u are the first chunk in first stage; bwd end ################ if self.stage_manager.is_first_stage(ignore_chunk=True): + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -425,6 +446,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; ################ if self.stage_manager.is_last_stage(ignore_chunk=True): + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -889,7 +911,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): for h in self.wait_handles: for hh in h: hh.wait() - + # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}") # return loss & output if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 528638f41..9640d8187 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -193,7 +193,7 @@ class LlamaPolicy(Policy): ) # not enable tp, replace layer to LinearWithGradAccum - else: + elif use_zbv: decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // tp_size, "self_attn.num_heads": num_q_heads, @@ -514,24 +514,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) - # enable tp, replace layer to LinearWithGradAccum - else: - # add a new item for sequence classification - new_item = { - LlamaForSequenceClassification: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", - target_module=LinearWithGradAccum, - kwargs=dict( - fp8_communication=self.shard_config.fp8_communication, - use_zbv=use_zbv, - ), - ) - ] - ) - } - policy.update(new_item) + # TODO: test lora bug here + # # enable tp, replace layer to LinearWithGradAccum + # else: + # # add a new item for sequence classification + # new_item = { + # LlamaForSequenceClassification: ModulePolicyDescription( + # sub_module_replacement=[ + # SubModuleReplacementDescription( + # suffix="score", + # target_module=LinearWithGradAccum, + # kwargs=dict( + # fp8_communication=self.shard_config.fp8_communication, + # use_zbv=use_zbv, + # ), + # ) + # ] + # ) + # } + # policy.update(new_item) # to be confirmed if self.pipeline_stage_manager: diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index bda3a5512..81e4c888f 100644 --- 
a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -916,12 +916,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @parameterize( "config", [ - # # Pass - # (1, 2, 2, 1), - # (1, 2, 1, 2), - # (1, 1, 2, 2), + # Pass + (1, 2, 2, 1), + (1, 2, 1, 2), + (1, 1, 2, 2), # TODO: acc err in pp4 - (1, 4, 1, 1), + # (1, 4, 1, 1), ], ) def run_with_booster_hybridplugin(config: Tuple[int, ...]): @@ -1065,16 +1065,16 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() - # assert param - for parall_name, parall_param in parallel_model.named_parameters(): - parall_name = ".".join(parall_name.split(".")[1:]) - for base_name, base_param in torch_model.named_parameters(): - if parall_name == base_name: - # assert weight - assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name) - # assert weight.grad - if parall_param.grad is not None: - assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad") + # # assert param + # for parall_name, parall_param in parallel_model.named_parameters(): + # parall_name = ".".join(parall_name.split(".")[1:]) + # for base_name, base_param in torch_model.named_parameters(): + # if parall_name == base_name: + # # assert weight + # assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name) + # # assert weight.grad + # if parall_param.grad is not None: + # assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad") assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") @@ -1086,7 +1086,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # run_with_booster_moehybridplugin() + run_with_booster_moehybridplugin() run_with_booster_hybridplugin() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index c0690e5fd..33707a4f6 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -420,4 +420,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - # test_llama_3d() + test_llama_3d() From 337debcf2a7a894a7d4501e8b07e78844a7e7bfa Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 11 Nov 2024 11:34:29 +0000 Subject: [PATCH 04/16] [feat] fix testcase; --- colossalai/pipeline/p2p.py | 2 -- colossalai/pipeline/schedule/zero_bubble_pp.py | 3 --- .../test_schedule/test_zerobubble_pp.py | 12 +++++++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 8dbb6ec78..8c319aceb 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -432,8 +432,6 @@ def _communicate( overlap_p2p=overlap_p2p, send_first=send_first if send_first != None else True, ) - # print(f"rank {dist.get_rank()}; recv_src {recv_src}; send_dst {send_dst}; metadata_send {metadata_send}; metadata_recv {metadata_recv};") - if metadata_recv is not None: assert isinstance(metadata_recv, P2PMetadata) tree_spec = metadata_recv.tree_spec diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py 
b/colossalai/pipeline/schedule/zero_bubble_pp.py index b608fc3a0..58b36f624 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -101,9 +101,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # init buffer self._free_buffers() - def _set_send_metadata_buffers(self, model_chunk_id): - pass - def _free_buffers(self): # free local buffer # two dim array, first dim is the model chunk, second dim is the microbatch queue diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 81e4c888f..3d0966070 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -752,13 +752,13 @@ def run_fwd_bwd_vschedule_with_optim(test_config): @parameterize( "config", [ - # # Pass + # Pass (1, 2, 1, 1, 2), + (1, 1, 2, 2, 1), + (1, 2, 1, 2, 1), # TODO: adapt mixtral with no TP Linear # (1, 2, 2, 1, 1), # (0, 1, 4, 1, 1), - # (1, 1, 2, 2, 1), - # (1, 2, 1, 2, 1), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @@ -1070,10 +1070,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): # parall_name = ".".join(parall_name.split(".")[1:]) # for base_name, base_param in torch_model.named_parameters(): # if parall_name == base_name: - # # assert weight + # # print(f"parall_name {parall_name} parall_param.grad {parall_param.grad is not None}, base_name {base_name} base_param.grad {base_param.grad is not None}") + # # # assert weight # assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name) - # # assert weight.grad + # # # assert weight.grad # if parall_param.grad is not None: + # # print(f"parall_param.grad {parall_param.grad}, base_param.grad {base_param.grad}") # assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad") assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) From 80b04d78550f370c9293195947bab0033d363f31 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 12 Nov 2024 07:28:49 +0000 Subject: [PATCH 05/16] [feat] support mixtral policy with zbv tp_Linear & non_tp_Linear --- .../pipeline/schedule/zero_bubble_pp.py | 3 +- colossalai/shardformer/policies/llama.py | 8 --- colossalai/shardformer/policies/mistral.py | 71 +++++++++++++++++++ .../test_schedule/test_zerobubble_pp.py | 46 ++++++------ 4 files changed, 99 insertions(+), 29 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 58b36f624..f678d7d7f 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -45,7 +45,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): num_model_chunks: int, num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, - enable_metadata_cache: bool = True, + enable_metadata_cache: bool = False, overlap_p2p: bool = True, ): super().__init__(stage_manager) @@ -679,6 +679,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) + # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}") # Step3: # 3-1:detach output; detach output for send fwd; diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 9640d8187..b18aa933c 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -194,15 
+194,7 @@ class LlamaPolicy(Policy): # not enable tp, replace layer to LinearWithGradAccum elif use_zbv: - decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // tp_size, - "self_attn.num_heads": num_q_heads, - } - if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads - policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attn.q_proj", diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 4d16038c1..b4b87df92 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -62,6 +63,8 @@ class MistralPolicy(Policy): if self.tie_weight: embedding_cls = PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -90,6 +93,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +101,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -104,6 +109,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +117,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -118,6 +125,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,6 +133,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -132,6 +141,68 @@ class MistralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + elif use_zbv: + policy[MistralDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + 
kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 3d0966070..ddb70e5f2 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -36,6 +36,24 @@ NUM_HEADS = 4 TOP_K = 1 +def register_hooks(module: torch.nn.Module): + + def fwd_hook(module, input, output): + torch.cuda.synchronize() + name = module._name if hasattr(module, "_name") else module + print(f"Fwd hook {name} \n output {output}") + + def bwd_hook(module, grad_input, grad_output): + torch.cuda.synchronize() + + def bwd_pre_hook(module, grad_output): + torch.cuda.synchronize() + + module.register_forward_hook(fwd_hook) + # module.register_backward_hook(bwd_hook) + # module.register_full_backward_pre_hook(bwd_pre_hook) + + class MlpModel(nn.Module): def __init__( self, @@ -756,9 +774,9 @@ def run_fwd_bwd_vschedule_with_optim(test_config): (1, 2, 1, 1, 2), (1, 1, 2, 2, 1), (1, 2, 1, 2, 1), - # TODO: adapt mixtral with no TP Linear - # (1, 2, 2, 1, 1), - # (0, 1, 4, 1, 1), + (1, 2, 2, 1, 1), + # # TODO: adapt mixtral with no TP Linear + (0, 1, 4, 1, 1), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @@ -872,7 +890,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): return_outputs=True, ) # stage 0 chunk 0 - parallel_output = None if ( booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) and rank == dist.get_process_group_ranks(plugin.pp_group)[0] @@ -880,6 +897,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): parallel_output = sharded_output["loss"] else: parallel_output = torch.tensor(12345.0, device="cuda") + print(f"rank {dist.get_rank()} parallel_output {parallel_output}") # broadcast along pp axis dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group) @@ -920,8 +938,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): (1, 2, 2, 1), (1, 2, 1, 2), (1, 1, 2, 2), - # TODO: acc err in pp4 - # (1, 4, 1, 1), + # TODO: support overlap p2p in pp4 + (1, 4, 1, 1), ], ) def run_with_booster_hybridplugin(config: Tuple[int, ...]): @@ -1030,7 +1048,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): return_outputs=True, ) # stage 0 chunk 0 - parallel_output = None if ( booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) and rank == dist.get_process_group_ranks(plugin.pp_group)[0] @@ -1054,6 +1071,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): all_inputs = [input_embeddings.clone() for _ in range(dp_size)] dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) torch_output_sum = 0 + # torch_model.apply(register_hooks) # register hook for base model for input_data_ in all_inputs: torch_output = 
torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() torch_output.backward() @@ -1065,19 +1083,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() - # # assert param - # for parall_name, parall_param in parallel_model.named_parameters(): - # parall_name = ".".join(parall_name.split(".")[1:]) - # for base_name, base_param in torch_model.named_parameters(): - # if parall_name == base_name: - # # print(f"parall_name {parall_name} parall_param.grad {parall_param.grad is not None}, base_name {base_name} base_param.grad {base_param.grad is not None}") - # # # assert weight - # assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name) - # # # assert weight.grad - # if parall_param.grad is not None: - # # print(f"parall_param.grad {parall_param.grad}, base_param.grad {base_param.grad}") - # assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad") - + print(f"parallel_output {parallel_output}, torch_output_sum {torch_output_sum}") assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") clear_layout_converter() From b6d5e618093ae2abe55729a4f9ec1ffab2710598 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 02:51:34 +0000 Subject: [PATCH 06/16] [feat] update mixtral policy & bert policy for zerobubble --- colossalai/shardformer/policies/bert.py | 98 ++++++++++++++++++++++ colossalai/shardformer/policies/mixtral.py | 78 ++++++++++++++++- 2 files changed, 173 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 09673d396..63cd49280 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -75,6 +75,8 @@ class BertPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -97,6 +99,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -105,6 +108,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -113,6 +117,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,6 +130,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -138,6 +144,7 @@ class BertPolicy(Policy): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -146,6 +153,97 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + 
target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[BertEmbeddings] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ] + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_intermediate_forward(), + }, + policy=policy, + target_key=BertIntermediate, + ) + + elif use_zbv: + policy[BertLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index ece72d929..54cd612f9 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -7,9 +7,18 @@ from torch import Tensor from torch.nn import Module from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col -from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D -from colossalai.shardformer.layer.linear import Linear1D_Row +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + LinearWithGradAccum, + PaddingEmbedding, + VocabParallelEmbedding1D, +) + +# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D +# from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.mixtral import ( EPMixtralSparseMoeBlock, MixtralPipelineForwards, @@ -166,6 +175,52 @@ class MixtralPolicy(Policy): ], ) + elif use_zbv: + 
policy[MixtralDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="block_sparse_moe.gate", + target_module=LinearWithGradAccum, + kwargs={ + "gather_output": True, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -351,6 +406,23 @@ class MixtralForCausalLMPolicy(MixtralPolicy): ) } policy.update(new_item) + elif use_zbv: + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ], + ) + } + policy.update(new_item) if self.pipeline_stage_manager: # set None as default From 1bc4dba3a3a8911f05eea8c8eb68cf5807ca75c8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 09:40:38 +0000 Subject: [PATCH 07/16] [fix] fix p2p error in zbv --- colossalai/pipeline/schedule/zero_bubble_pp.py | 8 +++----- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 5 +---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index f678d7d7f..31e6cfb38 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): num_model_chunks: int, num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, - enable_metadata_cache: bool = False, - overlap_p2p: bool = True, + enable_metadata_cache: bool = True, + overlap_p2p: bool = False, ): super().__init__(stage_manager) + # Not support overlap_p2p so far # batch info self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size @@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) - for h in self.wait_handles: - for hh in h: - hh.wait() # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}") # return loss & output if outputs is not None: diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ddb70e5f2..b630d30b1 100644 --- 
a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config): @parameterize( "config", [ - # Pass (1, 2, 1, 1, 2), (1, 1, 2, 2, 1), (1, 2, 1, 2, 1), (1, 2, 2, 1, 1), - # # TODO: adapt mixtral with no TP Linear - (0, 1, 4, 1, 1), + (1, 1, 4, 1, 1), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): (1, 2, 2, 1), (1, 2, 1, 2), (1, 1, 2, 2), - # TODO: support overlap p2p in pp4 (1, 4, 1, 1), ], ) From 014afbdb595a2ffa5271fd75ee9535ea3b533332 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 09:43:47 +0000 Subject: [PATCH 08/16] [fix] fix attn --- colossalai/shardformer/layer/attn.py | 63 ++++++++++++++++++---------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 3202ebf25..019a6b140 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -6,6 +6,7 @@ import torch.distributed import torch.distributed as dist import torch.nn.functional as F from einops import rearrange +from packaging import version from colossalai.kernel.kernel_loader import ( FlashAttentionDaoLoader, @@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function): max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 max_seqlen_half = max_seqlen // 2 - misc_kwargs = { - "window_size": (-1, -1), "alibi_slopes": None, "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, "dropout_p": dropout_p, @@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function): "softcap": 0.0, "return_softmax": False, } + import flash_attn + + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): + misc_kwargs["window_size_left"] = -1 + misc_kwargs["window_size_right"] = -1 + else: + misc_kwargs["window_size"] = (-1, -1) if ( RingAttention.HALF_INDICES is not None @@ -707,26 +713,39 @@ class RingAttention(torch.autograd.Function): # Helper to pass args to FA def _forward(q, k, v, causal): - ( - _, - _, - _, - _, - out, - softmax_lse, - _, - rng_state, - ) = _flash_attn_forward( - q, - k, - v, - cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, - cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, - max_seqlen_q if q.shape[0] == t else max_seqlen_half, - max_seqlen_kv if k.shape[0] == t else max_seqlen_half, - causal=causal, - **misc_kwargs, - ) + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): + (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) + else: + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) return out, softmax_lse, rng_state def _kv_comm(i): From 5c2ebbfd48ad7590b0278687db2e41ab99e398d4 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: 
Fri, 15 Nov 2024 05:58:56 +0000 Subject: [PATCH 09/16] [fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid; --- colossalai/pipeline/schedule/zero_bubble_pp.py | 13 +++++++++++-- colossalai/shardformer/modeling/mixtral.py | 1 - colossalai/shardformer/policies/mixtral.py | 2 -- examples/language/mixtral/benchmark.py | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 31e6cfb38..97ad9d5f5 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, - overlap_p2p: bool = False, + overlap_p2p: bool = True, ): super().__init__(stage_manager) # Not support overlap_p2p so far @@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] + # print(f"stage {self.stage_manager.stage} {scheduled_node.type}") if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] wait_handle = communication_func(scheduled_node.chunk) self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": + for h in self.wait_handles: + for hh in h: + hh.wait() self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): outputs=outputs, ) elif scheduled_node.type == "B": + for h in self.wait_handles: + for hh in h: + hh.wait() self.schedule_b( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) - # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}") + for h in self.wait_handles: + for hh in h: + hh.wait() # return loss & output if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 3687cfb99..a88db87bc 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -381,7 +381,6 @@ class MixtralPipelineForwards: output_router_logits, use_cache, ) - hidden_states = layer_outputs[0] if use_cache: diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 54cd612f9..fab437c01 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -214,7 +214,6 @@ class MixtralPolicy(Policy): suffix="block_sparse_moe.gate", target_module=LinearWithGradAccum, kwargs={ - "gather_output": True, "fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv, }, @@ -414,7 +413,6 @@ class MixtralForCausalLMPolicy(MixtralPolicy): suffix="lm_head", target_module=LinearWithGradAccum, kwargs=dict( - gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv, ), diff --git 
a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py index 0334bd81c..dbffd0c2a 100644 --- a/examples/language/mixtral/benchmark.py +++ b/examples/language/mixtral/benchmark.py @@ -122,7 +122,7 @@ def main(): num_ckpt_layers_per_stage=[19, 19, 19, 13], ), "num_layers_per_stage": [19, 20, 20, 21], - # "pp_style": "interleaved", + "pp_style": "interleaved", } if args.custom_ckpt else {} From cf86c1b1c56169a6ea65432619d7675d4f6b0f7b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 15 Nov 2024 07:56:14 +0000 Subject: [PATCH 10/16] [fix] fix zbv wait_handle --- .../pipeline/schedule/zero_bubble_pp.py | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 97ad9d5f5..0a97c466a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -115,10 +115,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): self.output_tensors_grad_dw = [[], []] # buffer for communication - self.send_forward_buffer = [[], []] - self.recv_forward_buffer = [[], []] - self.send_backward_buffer = [[], []] - self.recv_backward_buffer = [[], []] + self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_forward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] + self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_backward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] # y buffer for local send fwd self.local_send_forward_buffer = [] @@ -257,7 +263,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) return wait_handles else: @@ -280,7 +286,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) return wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -316,7 +322,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) return wait_handles else: @@ -339,7 +345,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) return wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) 
-> List: @@ -651,9 +657,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if model_chunk_id == 0: # is first stage; get input from microbatch if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = None + input_obj = None # (tensor, wait_handle) else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] else: # is last stage; recv from local if self.stage_manager.is_last_stage(ignore_chunk=True): @@ -661,7 +670,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) - + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] # Here, let input_obj.requires_grad_() # if input_obj is not None: if not isinstance(input_obj, torch.Tensor): @@ -751,6 +762,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # chunk0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): @@ -758,6 +772,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop(0) @@ -886,9 +903,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): wait_handle = communication_func(scheduled_node.chunk) self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": - for h in self.wait_handles: - for hh in h: - hh.wait() self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -898,9 +912,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): outputs=outputs, ) elif scheduled_node.type == "B": - for h in self.wait_handles: - for hh in h: - hh.wait() self.schedule_b( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -914,9 +925,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) - for h in self.wait_handles: - for hh in h: - hh.wait() # return loss & output if outputs is not None: outputs = merge_batch(outputs) From 0fb500c7d404a8e2fe306135b9b21f1b786868d7 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 15 Nov 2024 09:47:05 +0000 Subject: [PATCH 11/16] [fix] rm debug info; update llama policy; update wait handle --- .../pipeline/schedule/zero_bubble_pp.py | 6 ++- colossalai/shardformer/policies/llama.py | 37 +++++++++---------- .../test_schedule/test_zerobubble_pp.py | 19 ---------- 3 files changed, 22 insertions(+), 40 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 0a97c466a..92d214bad 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) - # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}") # Step3: # 3-1:detach output; detach output for send fwd; @@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): 
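Patches 09 and 10 together settle how asynchronous P2P results are consumed under `overlap_p2p=True` (now the default): rather than draining every outstanding handle before each F/B node, each entry in `recv_forward_buffer` / `recv_backward_buffer` carries the wait handles of the transfer that produced it, and `schedule_f` / `schedule_b` block only on the handles attached to the tensor they are about to pop. A self-contained sketch of that deferred-wait pattern (class and method names are illustrative; the real scheduler keeps one FIFO per model chunk, as the hunks above show):

from typing import Any, List, Tuple

import torch


class DeferredRecvBuffer:
    """Per-chunk FIFO of (tensor, wait_handles) pairs for overlapped P2P."""

    def __init__(self, num_chunks: int = 2) -> None:
        self.buffers: List[List[Tuple[torch.Tensor, List[Any]]]] = [[] for _ in range(num_chunks)]

    def push(self, chunk: int, tensor: torch.Tensor, handles: List[Any]) -> None:
        # Called right after comm.recv_forward / comm.recv_backward returns.
        self.buffers[chunk].append((tensor, handles))

    def pop_ready(self, chunk: int) -> torch.Tensor:
        # Called by the compute step that consumes the tensor; only now do we block.
        tensor, handles = self.buffers[chunk].pop(0)
        for h in handles:  # e.g. torch.distributed Work objects from async isend/irecv
            h.wait()
        return tensor

In the patch this is exactly the `(input_tensor, wait_handles)` tuple appended by `recv_forward` and unpacked at the top of `schedule_f`, with the analogous `(output_tensor_grad, wait_handles)` pair consumed in `schedule_b`.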
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] - # print(f"stage {self.stage_manager.stage} {scheduled_node.type}") if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] @@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) + # wait here to ensure all communication is done + for h in self.wait_handles: + for hh in h: + hh.wait() # return loss & output if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b18aa933c..d962057b1 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -506,25 +506,24 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) - # TODO: test lora bug here - # # enable tp, replace layer to LinearWithGradAccum - # else: - # # add a new item for sequence classification - # new_item = { - # LlamaForSequenceClassification: ModulePolicyDescription( - # sub_module_replacement=[ - # SubModuleReplacementDescription( - # suffix="score", - # target_module=LinearWithGradAccum, - # kwargs=dict( - # fp8_communication=self.shard_config.fp8_communication, - # use_zbv=use_zbv, - # ), - # ) - # ] - # ) - # } - # policy.update(new_item) + # enable tp, replace layer to LinearWithGradAccum + else: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) # to be confirmed if self.pipeline_stage_manager: diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index b630d30b1..ba6e82e88 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -36,24 +36,6 @@ NUM_HEADS = 4 TOP_K = 1 -def register_hooks(module: torch.nn.Module): - - def fwd_hook(module, input, output): - torch.cuda.synchronize() - name = module._name if hasattr(module, "_name") else module - print(f"Fwd hook {name} \n output {output}") - - def bwd_hook(module, grad_input, grad_output): - torch.cuda.synchronize() - - def bwd_pre_hook(module, grad_output): - torch.cuda.synchronize() - - module.register_forward_hook(fwd_hook) - # module.register_backward_hook(bwd_hook) - # module.register_full_backward_pre_hook(bwd_pre_hook) - - class MlpModel(nn.Module): def __init__( self, @@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): all_inputs = [input_embeddings.clone() for _ in range(dp_size)] dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) torch_output_sum = 0 - # torch_model.apply(register_hooks) # register hook for base model for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() torch_output.backward() From 2980da559fb95fc6fc765eb86243c9f56654ffc8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 15 Nov 2024 10:26:30 +0000 Subject: [PATCH 12/16] [fix] fix test_lora --- 
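Patch 11 also re-enables the previously commented-out branch of `LlamaForSequenceClassificationPolicy`, swapping the `score` head for `LinearWithGradAccum` so the ZBV schedule can manage its weight gradients; the next patch narrows that branch to `elif use_zbv:` after the LoRA test regression. A sketch of the policy entry being toggled here, lifted from the hunk above (import paths are assumed to mirror the ones already used in `llama.py`; the helper function is illustrative):

from transformers import LlamaForSequenceClassification

from colossalai.shardformer.layer import LinearWithGradAccum
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)


def score_head_replacement(fp8_communication: bool, use_zbv: bool) -> dict:
    # Replace the classification head with a gradient-accumulating Linear that is
    # not split across tensor-parallel ranks.
    return {
        LlamaForSequenceClassification: ModulePolicyDescription(
            sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="score",
                    target_module=LinearWithGradAccum,
                    kwargs=dict(
                        fp8_communication=fp8_communication,
                        use_zbv=use_zbv,
                    ),
                )
            ]
        )
    }

The returned dict is what the policy merges in via `policy.update(new_item)`.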
colossalai/shardformer/policies/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index d962057b1..b4a1f4bd8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -506,8 +506,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) + # TODO: test lora bug here # enable tp, replace layer to LinearWithGradAccum - else: + elif use_zbv: # add a new item for sequence classification new_item = { LlamaForSequenceClassification: ModulePolicyDescription( From f48a85e91d88133389ee53bdcc7fbd5dad982b9d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 15 Nov 2024 10:27:13 +0000 Subject: [PATCH 13/16] [fix] fix test_lora in llama policy --- colossalai/shardformer/policies/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b4a1f4bd8..e8f9471f9 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -506,7 +506,6 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) - # TODO: test lora bug here # enable tp, replace layer to LinearWithGradAccum elif use_zbv: # add a new item for sequence classification From 9a21f87ed6e161b88378490c026210b4f261c98b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 18 Nov 2024 02:50:14 +0000 Subject: [PATCH 14/16] [fix] fix wait handle in run_fwd_bwd --- colossalai/pipeline/schedule/zero_bubble_pp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 92d214bad..498240878 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -899,7 +899,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # communication communication_func = self.communication_map[scheduled_node.type] wait_handle = communication_func(scheduled_node.chunk) - self.wait_handles.append(wait_handle) + # We wait recv handle in fwd step and bwd step. 
Here only need to wait for send handle + if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}: + self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, From dafda0fb7082506ad76b5deff3024b3d5dbb904b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 18 Nov 2024 03:32:04 +0000 Subject: [PATCH 15/16] [fix] remove debug info; --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ba6e82e88..a01b75eee 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -760,7 +760,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config): ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): - test_config = config stage, ep_size, pp_size, tp_size, sp_size = config num_microbatches = pp_size dist.get_world_size() @@ -877,7 +876,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): parallel_output = sharded_output["loss"] else: parallel_output = torch.tensor(12345.0, device="cuda") - print(f"rank {dist.get_rank()} parallel_output {parallel_output}") # broadcast along pp axis dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group) @@ -905,7 +903,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) - print(f"rank {dist.get_rank()} config {test_config} test passed") clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() @@ -1060,10 +1057,8 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): p.grad /= dp_size torch_optimizer.step() torch_optimizer.zero_grad() - - print(f"parallel_output {parallel_output}, torch_output_sum {torch_output_sum}") assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) - print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") + clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() From 41fdd2139ba60e4305c701d25b7bf88d1e4d223b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 18 Nov 2024 16:48:21 +0800 Subject: [PATCH 16/16] [fix] rm unused comments --- colossalai/pipeline/schedule/zero_bubble_pp.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 498240878..89c868aae 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -49,7 +49,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): overlap_p2p: bool = True, ): super().__init__(stage_manager) - # Not support overlap_p2p so far # batch info self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size @@ -543,8 +542,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_grad_ = [] # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. 
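Patch 14 is the counterpart of the deferred-recv change: since recv handles are now waited inside `schedule_f` / `schedule_b`, the scheduler's global `wait_handles` list only needs to track SEND_FORWARD / SEND_BACKWARD work, which is drained once after the last scheduled node. A compact sketch of that control flow (node and map names follow the scheduler above but are simplified, so the function itself is illustrative):

from typing import Any, Callable, Dict, List

SEND_TYPES = {"SEND_FORWARD", "SEND_BACKWARD"}


def run_schedule(nodes: List[Any], comm_map: Dict[str, Callable[[int], List[Any]]]) -> None:
    pending_sends: List[List[Any]] = []
    for node in nodes:
        if node.type in comm_map:
            handles = comm_map[node.type](node.chunk)
            if node.type in SEND_TYPES:
                # Keep send handles only; recv handles travel with their tensors
                # and are waited by the F/B step that consumes them.
                pending_sends.append(handles)
        # ... F / B / W compute steps elided ...
    # Ensure all outgoing sends have completed before buffers are reused.
    for handles in pending_sends:
        for h in handles:
            h.wait()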
- # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - # return None # For loss backward; output_obj is loss; output_obj_grad should be None if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): @@ -718,10 +715,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Do not release_tensor_data loss, release_tensor_data other output_obj; if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): self.output_tensors[model_chunk_id].append(output_obj) - # self.output_tensors_dw[model_chunk_id].append(output_obj) else: self.output_tensors[model_chunk_id].append(output_obj) - # self.output_tensors_dw[model_chunk_id].append(output_obj) # add output to send_fwd_buffer if model_chunk_id == 0: # chunk 0
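The truncated hunk at the end covers the other half of the bookkeeping: each forward output is stored per chunk for the matching backward pass, and chunk 0's output on the last stage is handed to chunk 1 through the local buffer instead of a P2P send (the bottom of the V). A simplified sketch of that routing, assuming the buffer names visible in the scheduler; the real `schedule_f` also detaches the tensor before queueing it for send and treats chunk 1 on the first stage as the loss-producing end:

import torch


def route_forward_output(
    scheduler,
    model_chunk_id: int,
    output_obj: torch.Tensor,
    is_first_stage: bool,
    is_last_stage: bool,
) -> None:
    # Keep the forward output for this chunk's backward pass.
    scheduler.output_tensors[model_chunk_id].append(output_obj)

    if model_chunk_id == 0:
        if is_last_stage:
            # Bottom of the V: hand off to chunk 1 on the same rank, no send needed.
            scheduler.local_send_forward_buffer.append(output_obj)
        else:
            scheduler.send_forward_buffer[model_chunk_id].append(output_obj)
    elif not is_first_stage:
        # Chunk 1 keeps passing activations back along the V until the first
        # stage, where the output is the loss and nothing is sent forward.
        scheduler.send_forward_buffer[model_chunk_id].append(output_obj)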