mirror of https://github.com/hpcaitech/ColossalAI
[util] fixed activation checkpointing on torch 1.9 (#719)
parent 04ff5ea546
commit 2412429d54
@@ -68,7 +68,10 @@ class CheckpointFunction(torch.autograd.Function):
             else:
                 ctx.inputs.append(arg)
 
-        ctx.save_for_backward(*tensor_inputs)
+        if activation_offload:
+            ctx.tensor_inputs = tensor_inputs
+        else:
+            ctx.save_for_backward(*tensor_inputs)
         return outputs
 
     @staticmethod
@@ -79,7 +82,11 @@ class CheckpointFunction(torch.autograd.Function):
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
-        tensors = ctx.saved_tensors
+
+        if ctx.activation_offload:
+            tensors = ctx.tensor_inputs
+        else:
+            tensors = ctx.saved_tensors
 
         # store the current states
         bwd_cpu_rng_state = torch.get_rng_state()
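Taken together, the two hunks make CheckpointFunction.forward() stash its tensor inputs directly on ctx (ctx.tensor_inputs) when activation_offload is enabled, falling back to ctx.save_for_backward(*tensor_inputs) otherwise, and make backward() read the tensors from whichever store forward() used. Below is a minimal, self-contained sketch of that pattern, not the ColossalAI implementation: the class name SketchCheckpoint, the explicit CPU offload of the stashed tensors, and the gradient plumbing are illustrative assumptions, and the RNG-state bookkeeping visible in the second hunk (bwd_cpu_rng_state) is omitted.

import torch


class SketchCheckpoint(torch.autograd.Function):
    """Illustrative checkpoint: recompute the forward pass during backward.

    Mirrors the commit's pattern of choosing between ctx.tensor_inputs and
    ctx.save_for_backward depending on activation_offload.
    """

    @staticmethod
    def forward(ctx, run_function, activation_offload, *args):
        ctx.run_function = run_function
        ctx.activation_offload = activation_offload

        # Run the wrapped function without building a graph; it will be
        # re-executed with grad enabled inside backward().
        with torch.no_grad():
            outputs = run_function(*args)

        # Split args into tensor and non-tensor inputs, remembering positions.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        if activation_offload:
            # Keep offloaded CPU copies on ctx instead of using
            # save_for_backward, mirroring the change in the first hunk.
            ctx.devices = [t.device for t in tensor_inputs]
            ctx.tensor_inputs = [t.detach().cpu() for t in tensor_inputs]
        else:
            ctx.save_for_backward(*tensor_inputs)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Pick the tensor source that matches the choice made in forward(),
        # mirroring the change in the second hunk.
        if ctx.activation_offload:
            tensors = [t.to(d) for t, d in zip(ctx.tensor_inputs, ctx.devices)]
        else:
            tensors = ctx.saved_tensors

        # Rebuild the full argument list and recompute with grad enabled.
        inputs = list(ctx.inputs)
        for idx, tensor in zip(ctx.tensor_indices, tensors):
            inputs[idx] = tensor.detach().requires_grad_(tensor.is_floating_point())

        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if torch.is_tensor(outputs):
            outputs = (outputs,)

        torch.autograd.backward(outputs, grad_outputs)
        grads = tuple(arg.grad if torch.is_tensor(arg) else None for arg in inputs)
        # No gradients for run_function and activation_offload themselves.
        return (None, None) + grads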
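Assuming the sketch above, usage follows the usual checkpoint pattern: the wrapped block is re-run inside backward(), so only the (optionally CPU-offloaded) inputs are held between the two passes. The example values here are arbitrary.

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(4, 16, requires_grad=True)
y = SketchCheckpoint.apply(block, True, x)  # True -> offload the stashed input to CPU
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])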