diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index 88cc7e202..2edd6b1a5 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/utils/activation_checkpoint.py @@ -68,7 +68,10 @@ class CheckpointFunction(torch.autograd.Function): else: ctx.inputs.append(arg) - ctx.save_for_backward(*tensor_inputs) + if activation_offload: + ctx.tensor_inputs = tensor_inputs + else: + ctx.save_for_backward(*tensor_inputs) return outputs @staticmethod @@ -79,7 +82,11 @@ class CheckpointFunction(torch.autograd.Function): # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors + + if ctx.activation_offload: + tensors = ctx.tensor_inputs + else: + tensors = ctx.saved_tensors # store the current states bwd_cpu_rng_state = torch.get_rng_state()