From 1342a983b10a1d44632fce5545e3a1a107687082 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 26 Sep 2024 11:05:27 +0000 Subject: [PATCH] [fix] rm print & comments; --- colossalai/pipeline/schedule/zero_bubble_pp.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 8562d23f2..5c25c5bfa 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -478,11 +478,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_ = [] output_obj_grad_ = [] - # For chunk 0 stage 0, use micro_batch as input_obj_ + # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - # input_obj_, _ = tree_flatten(micro_batch) - # output_obj_, _ = tree_flatten(output_obj) # y - # output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy return None # For loss backward; output_obj is loss; output_obj_grad should be None @@ -513,9 +510,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Format output_obj_grad input_obj_grad = {} if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - # for k, v in micro_batch.items(): - # if isinstance(v, torch.Tensor) and v.grad is not None: - # input_obj_grad[k] = v.grad pass else: for k, v in input_obj.items(): @@ -645,7 +639,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): tree_map(release_tensor_data, output_obj) # add input and output object for backward b - # self.input_tensors[model_chunk_id].append((micro_batch, input_obj)) self.input_tensors[model_chunk_id].append(input_obj) # for bwd b&w, we only need the graph(grad_fn) of output_obj @@ -704,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) # get input and output object from buffer; - # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0) input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) @@ -841,7 +833,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # while we still have schedules_node in self.schedules schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) - print(f"schedule {schedule}") for it in range(len(schedule)): scheduled_node = schedule[it] if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: