mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] hotfix backward for multiple outputs (#6090)
* [pipeline] hotfix backward for multiple outputs
parent 62c13e7969
commit cd61353bae
@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
-            else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
+            else:
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
-            else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
+            else:
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
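Both hunks apply the same fix: instead of issuing one backward_by_grad call per output tensor, the schedule collects all output tensors and their gradients and backpropagates them in a single call. A minimal, standalone sketch of why this matters, assuming backward_by_grad ultimately dispatches to torch.autograd.backward (the tensor names and values below are illustrative, not taken from the commit):

    import torch

    x = torch.randn(4, requires_grad=True)
    h = torch.sin(x)                      # shared intermediate; sin saves x for its backward
    output_obj = {"a": h * 2, "b": h * 3}
    output_obj_grad = {"a": torch.ones(4), "b": torch.ones(4)}

    # Old behavior: one backward call per key walks the shared graph twice.
    # The second call raises RuntimeError ("Trying to backward through the
    # graph a second time") because the first pass freed the saved tensors:
    #
    #   for k, grad in output_obj_grad.items():
    #       torch.autograd.backward(output_obj[k], grad)
    #
    # New behavior: a single call with lists traverses the shared graph once.
    torch.autograd.backward(
        [output_obj[k] for k in output_obj_grad],
        [output_obj_grad[k] for k in output_obj_grad],
    )
    print(torch.allclose(x.grad, 5 * torch.cos(x)))  # True: d(2 sin x + 3 sin x)/dx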