[util] fixed activation checkpointing on torch 1.9 (#719)

pull/729/head
Frank Lee 3 years ago committed by GitHub
parent 04ff5ea546
commit 2412429d54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -68,6 +68,9 @@ class CheckpointFunction(torch.autograd.Function):
else:
ctx.inputs.append(arg)
if activation_offload:
ctx.tensor_inputs = tensor_inputs
else:
ctx.save_for_backward(*tensor_inputs)
return outputs
@@ -79,6 +82,10 @@ class CheckpointFunction(torch.autograd.Function):
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
if ctx.activation_offload:
tensors = ctx.tensor_inputs
else:
tensors = ctx.saved_tensors
# store the current states

Loading…
Cancel
Save