diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 8fcb2aa56..1af62cc8a 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -489,12 +489,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # for k, v in micro_batch.items():
-            #     if v.requires_grad:
-            #         input_obj_.append(micro_batch[k])
-            #         output_obj_.append(output_obj[k])  # y
-            #         output_obj_grad_.append(output_obj_grad[k])  # dy
-
             input_obj_, _ = tree_flatten(micro_batch)
             output_obj_, _ = tree_flatten(output_obj)  # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
@@ -502,22 +496,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            # for k, v in input_obj.items():
-            #     if v.requires_grad:
-            #         input_obj_.append(input_obj[k])
             input_obj_, _ = tree_flatten(input_obj)
-            # output_obj_.append(output_obj)  # LOSS
-            # output_obj_grad_.append(output_obj_grad)  # None
             output_obj_, _ = tree_flatten(output_obj)  # LOSS
             output_obj_grad_, _ = tree_flatten(output_obj_grad)  # None
 
         # For other chunk stage, use input_obj as input_obj_;
         else:
-            # for k, v in input_obj.items():
-            #     if v.requires_grad:
-            #         input_obj_.append(input_obj[k])
-            #         output_obj_.append(output_obj[k])  # y
-            #         output_obj_grad_.append(output_obj_grad[k])  # dy
             input_obj_, _ = tree_flatten(input_obj)
             output_obj_, _ = tree_flatten(output_obj)  # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
@@ -572,10 +556,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            # for k, v in output_obj.items():
-            #     if v.requires_grad:
-            #         output_obj_.append(output_obj[k])
-            #         output_obj_grad_.append(output_obj_grad[k])
             output_obj_, _ = tree_flatten(output_obj)  # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
 
@@ -653,7 +633,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            tree_map(deallocate, deallocate_output_obj)
 
        # add input and output object for backward b
-
        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
 
        # for bwd b&w, we only need the graph(grad_fn) of output_obj
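
The deletions above are stale hand-written drafts that walked the input dicts and filtered values by requires_grad; the live code already flattens whatever nested structure it receives with tree_flatten. Below is a minimal standalone sketch of the difference, assuming tree_flatten resolves to torch.utils._pytree.tree_flatten (the import this module uses); the example dict keys are hypothetical and for illustration only.

import torch
from torch.utils._pytree import tree_flatten

micro_batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),                  # no grad
    "hidden_states": torch.randn(2, 8, 16, requires_grad=True),  # grad-requiring
}

# Removed draft: manual dict walk, keeping only grad-requiring values.
manual = [v for v in micro_batch.values() if v.requires_grad]

# Kept code path: flatten every leaf of the (possibly nested) pytree in one call.
flat, spec = tree_flatten(micro_batch)

print(len(manual), len(flat))  # 1 2 -- tree_flatten keeps all leaves; it does not filter by requires_grad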