[pipeline] hotfix backward for multiple outputs (#6090)

* [pipeline] hotfix backward for multiple outputs

Hongxin Liu 2024-10-16 17:27:33 +08:00 committed by GitHub
parent 62c13e7969
commit cd61353bae
2 changed files with 18 additions and 16 deletions


@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
             else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
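
Why the change: running backward once per output tensor, as the old branches did, can fail when the tensors share an autograd graph, because the first call frees the shared graph and the second raises a RuntimeError. The patch instead gathers every tensor and its incoming gradient and issues one batched backward. A minimal sketch of the failure and the fix in plain PyTorch, assuming backward_by_grad ultimately dispatches to torch.autograd.backward (which accepts lists of tensors and grads):

    import torch

    x = torch.randn(4, requires_grad=True)
    shared = torch.tanh(x)   # tanh saves its output for the backward pass
    out1 = shared.sum()
    out2 = (shared * 3).sum()

    try:
        out1.backward()      # frees the shared part of the graph
        out2.backward()      # fails: backward through the graph a second time
    except RuntimeError as e:
        print(e)

    # Rebuild the graph and walk it once for both outputs, mirroring the
    # single batched call the patch makes.
    x.grad = None
    shared = torch.tanh(x)
    out1, out2 = shared.sum(), (shared * 3).sum()
    torch.autograd.backward([out1, out2], [torch.ones_like(out1), torch.ones_like(out2)])
    print(x.grad)            # contributions from both outputs, accumulated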


@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
             else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
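
Both schedules get the identical fix. Note how output_obj.get("backward_tensor_keys", output_obj_grad.keys()) folds the old two-branch logic into one line: if the model tagged a subset of its outputs via "backward_tensor_keys", only that subset is backwarded; otherwise every output that received a gradient is. The len(...) == 1 special case keeps the common single-output path passing a bare tensor rather than a one-element list. A toy illustration of the key selection with stand-in dicts (hypothetical tensors, not the real schedule objects):

    import torch

    t1, t2 = torch.randn(2, requires_grad=True), torch.randn(2, requires_grad=True)
    output_obj_grad = {"loss": torch.ones(2), "aux": torch.ones(2)}

    # No explicit key list: fall back to every key that received a grad.
    output_obj = {"loss": t1, "aux": t2}
    print(list(output_obj.get("backward_tensor_keys", output_obj_grad.keys())))
    # ['loss', 'aux']

    # The model restricts backward to a subset of its outputs.
    output_obj = {"loss": t1, "aux": t2, "backward_tensor_keys": ["loss"]}
    print(output_obj.get("backward_tensor_keys", output_obj_grad.keys()))
    # ['loss']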