diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 66fbc827b..8562d23f2 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, - micro_batch: Optional[dict], + # micro_batch: Optional[dict], input_obj: Optional[dict], output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], @@ -480,9 +480,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # For chunk 0 stage 0, use micro_batch as input_obj_ if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj_, _ = tree_flatten(micro_batch) - output_obj_, _ = tree_flatten(output_obj) # y - output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + # input_obj_, _ = tree_flatten(micro_batch) + # output_obj_, _ = tree_flatten(output_obj) # y + # output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + return None # For loss backward; output_obj is loss; output_obj_grad should be None elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): @@ -512,9 +513,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Format output_obj_grad input_obj_grad = {} if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - for k, v in micro_batch.items(): - if isinstance(v, torch.Tensor) and v.grad is not None: - input_obj_grad[k] = v.grad + # for k, v in micro_batch.items(): + # if isinstance(v, torch.Tensor) and v.grad is not None: + # input_obj_grad[k] = v.grad + pass else: for k, v in input_obj.items(): if isinstance(v, torch.Tensor) and v.grad is not None: @@ -643,7 +645,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): tree_map(release_tensor_data, output_obj) # add input and output object for backward b - self.input_tensors[model_chunk_id].append((micro_batch, input_obj)) + # self.input_tensors[model_chunk_id].append((micro_batch, input_obj)) + self.input_tensors[model_chunk_id].append(input_obj) # for bwd b&w, we only need the graph(grad_fn) of output_obj # Do not release_tensor_data loss, release_tensor_data other output_obj; @@ -701,7 +704,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) # get input and output object from buffer; - micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0) + # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0) + input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) # save output_tensor_grad for dw @@ -717,7 +721,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk=model_chunk, model_chunk_id=model_chunk_id, optimizer=optimizer, - micro_batch=micro_batch, input_obj=input_obj, output_obj=output_obj, output_obj_grad=output_tensor_grad, @@ -838,6 +841,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # while we still have schedules_node in self.schedules schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) + print(f"schedule {schedule}") for it in range(len(schedule)): scheduled_node = schedule[it] if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: