diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index bab118b85..7bdb6d11e 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -226,7 +226,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; cause u are chunk 0 in first rank, u have no prev rank; ################# if self.stage_manager.is_first_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -241,7 +240,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) - # return input_tensor, wait_handles return wait_handles else: @@ -265,7 +263,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) - # return input_tensor, wait_handles return wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -313,7 +310,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; get loss from local ################ if self.stage_manager.is_first_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -328,7 +324,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) - # return output_tensor_grad, wait_handles return wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -665,7 +660,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) - # print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};") # Step3: # 3-1:detach output; detach output for send fwd; @@ -748,20 +742,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) - # # save output_tensor_grad for dw - # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # # we save loss here - # self.output_tensors_grad_dw[model_chunk_id].append(output_obj) - # else: - # # we save output_tensor_grad here - # self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) - # the_output_obj_grad = [] - # if isinstance(output_obj, dict): - # for (k, v) in output_obj.items(): - # the_output_obj_grad.append(v.requires_grad) - # else: - # the_output_obj_grad.append(output_obj.requires_grad) - input_object_grad = self.backward_b_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, @@ -804,20 +784,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): Returns: Nothing. """ - - # get y & dy from buffer - # output_obj = self.output_tensors_dw[model_chunk_id].pop(0) - # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) WeightGradStore.pop(chunk=model_chunk_id) - # self.backward_w_step( - # model_chunk=model_chunk, - # model_chunk_id=model_chunk_id, - # optimizer=optimizer, - # output_obj=output_obj, - # output_obj_grad=output_obj_grad, - # ) - def run_forward_only( self, model_chunk: Union[ModuleList, Module], @@ -890,7 +858,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] - # print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};") if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index b8ef09bea..bda3a5512 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -749,12 +749,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config): assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) -# TODO:3) support booster & Hybrid base 2) -def run_with_hybridplugin(test_config): - pass - - -# TODO:4) support booster & MoEHybrid base 2) @parameterize( "config", [ @@ -923,9 +917,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): "config", [ # # Pass - (1, 2, 2, 1), - (1, 2, 1, 2), - (1, 1, 2, 2), + # (1, 2, 2, 1), + # (1, 2, 1, 2), + # (1, 1, 2, 2), # TODO: acc err in pp4 (1, 4, 1, 1), ], @@ -1071,6 +1065,17 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() + # assert param + for parall_name, parall_param in parallel_model.named_parameters(): + parall_name = ".".join(parall_name.split(".")[1:]) + for base_name, base_param in torch_model.named_parameters(): + if parall_name == base_name: + # assert weight + assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name) + # assert weight.grad + if parall_param.grad is not None: + assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad") + assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") clear_layout_converter() @@ -1081,7 +1086,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_with_booster_moehybridplugin() + # run_with_booster_moehybridplugin() run_with_booster_hybridplugin() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 33707a4f6..c0690e5fd 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -420,4 +420,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - test_llama_3d() + # test_llama_3d()