mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] hotfix backward for multiple outputs (#6090)
* [pipeline] hotfix backward for multiple outputs * [pipeline] hotfix backward for multiple outputspull/6092/head
parent
62c13e7969
commit
cd61353bae
|
@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if output_obj_grad is None:
|
||||
optimizer.backward(output_obj)
|
||||
else:
|
||||
if "backward_tensor_keys" not in output_obj:
|
||||
for k, grad in output_obj_grad.items():
|
||||
optimizer.backward_by_grad(output_obj[k], grad)
|
||||
keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
|
||||
tensors_to_backward = []
|
||||
grads_to_backward = []
|
||||
for k in keys:
|
||||
tensors_to_backward.append(output_obj[k])
|
||||
grads_to_backward.append(output_obj_grad[k])
|
||||
if len(tensors_to_backward) == 1:
|
||||
optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
|
||||
else:
|
||||
for k, grad in output_obj_grad.items():
|
||||
output_obj[k].grad = grad
|
||||
for k in output_obj["backward_tensor_keys"]:
|
||||
tensor_to_backward = output_obj[k]
|
||||
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
|
||||
optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
|
||||
|
||||
# Collect the grad of the input_obj.
|
||||
input_obj_grad = None
|
||||
|
|
|
@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if output_obj_grad is None:
|
||||
optimizer.backward(output_obj)
|
||||
else:
|
||||
if "backward_tensor_keys" not in output_obj:
|
||||
for k, grad in output_obj_grad.items():
|
||||
optimizer.backward_by_grad(output_obj[k], grad)
|
||||
keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
|
||||
tensors_to_backward = []
|
||||
grads_to_backward = []
|
||||
for k in keys:
|
||||
tensors_to_backward.append(output_obj[k])
|
||||
grads_to_backward.append(output_obj_grad[k])
|
||||
if len(tensors_to_backward) == 1:
|
||||
optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
|
||||
else:
|
||||
for k, grad in output_obj_grad.items():
|
||||
output_obj[k].grad = grad
|
||||
for k in output_obj["backward_tensor_keys"]:
|
||||
tensor_to_backward = output_obj[k]
|
||||
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
|
||||
optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
|
||||
|
||||
# Collect the grad of the input_obj.
|
||||
input_obj_grad = None
|
||||
|
|
Loading…
Reference in New Issue