diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index c538ee071..5da98364d 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
             else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 0fc90995a..224d63688 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
             else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
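
For context, here is a minimal standalone sketch of the dispatch logic both hunks introduce. `ToyOptimizerWrapper` and `backward_dispatch` are hypothetical names for illustration; the sketch assumes the real `backward_by_grad` ultimately delegates to `torch.autograd.backward`, which accepts either a single tensor and gradient or matching sequences of them.

```python
import torch


class ToyOptimizerWrapper:
    """Hypothetical stand-in for the schedules' optimizer wrapper.

    Assumes backward_by_grad delegates to torch.autograd.backward, which
    handles both a single (tensor, grad) pair and matching sequences.
    """

    def backward_by_grad(self, tensor, grad):
        torch.autograd.backward(tensor, grad)


def backward_dispatch(optimizer, output_obj, output_obj_grad):
    """Mirrors the new logic from the diff: collect the tensors named by
    'backward_tensor_keys' (falling back to every key that received a
    gradient) and issue one backward call, batched when there are several."""
    keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
    tensors_to_backward = [output_obj[k] for k in keys]
    grads_to_backward = [output_obj_grad[k] for k in keys]
    if len(tensors_to_backward) == 1:
        optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
    else:
        optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)


if __name__ == "__main__":
    x = torch.randn(2, 3, requires_grad=True)
    h = x * 2  # intermediate node shared by both outputs
    output_obj = {"out_a": h + 1.0, "out_b": h * 3}
    output_obj_grad = {"out_a": torch.ones(2, 3), "out_b": torch.ones(2, 3)}
    backward_dispatch(ToyOptimizerWrapper(), output_obj, output_obj_grad)
    print(x.grad)  # 2 + 6 = 8 everywhere: one traversal accumulates both paths
```

A likely motivation for the batching: when several output tensors share upstream autograd graph (as `out_a` and `out_b` share `h` above), calling backward once per tensor, as the old per-key loop did, either fails once the first call frees the shared graph or forces `retain_graph=True`; a single batched `torch.autograd.backward` traverses the shared graph once and accumulates all gradients correctly.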