[fix] assert that output_obj_grad is None at the bwd b step; replace input_obj.requires_grad_() with tree_map(torch.Tensor.requires_grad_, input_obj);

pull/6034/head
duanjunwen 3 months ago
parent 7568b34626
commit ce58d8e8bf

@ -475,8 +475,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
tree_map(retain_grad, input_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
@ -554,7 +555,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
input_obj.requires_grad_()
# Enable gradient tracking on input_obj by mapping torch.Tensor.requires_grad_ over it with tree_map
tree_map(torch.Tensor.requires_grad_, input_obj)
# Step2: fwd step
output_obj = self.forward_step(

Loading…
Cancel
Save