From ce58d8e8bf8c8807eb37b29fff8495b155279274 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 9 Sep 2024 08:19:58 +0000 Subject: [PATCH] [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; --- colossalai/pipeline/schedule/zero_bubble_pp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 622e7eb08..c1c4f13c6 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -475,8 +475,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): tree_map(retain_grad, input_obj) if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # loss backward; output_obj is loss - output_obj_grad = None + # loss backward; output_obj is loss; so output_obj_grad should be None + assert output_obj_grad is None + optimizer.backward_by_grad( tensor=output_obj, grad=output_obj_grad, @@ -554,7 +555,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) - input_obj.requires_grad_() + + # Here, let input_obj.requires_grad_() + tree_map(torch.Tensor.requires_grad_, input_obj) # Step2: fwd step output_obj = self.forward_step(