From 2412429d540442febc4ef920cf9e06c54d0c8571 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 12 Apr 2022 09:35:45 +0800
Subject: [PATCH] [util] fixed activation checkpointing on torch 1.9 (#719)

---
 colossalai/utils/activation_checkpoint.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py
index 88cc7e202..2edd6b1a5 100644
--- a/colossalai/utils/activation_checkpoint.py
+++ b/colossalai/utils/activation_checkpoint.py
@@ -68,7 +68,10 @@ class CheckpointFunction(torch.autograd.Function):
             else:
                 ctx.inputs.append(arg)
 
-        ctx.save_for_backward(*tensor_inputs)
+        if activation_offload:
+            ctx.tensor_inputs = tensor_inputs
+        else:
+            ctx.save_for_backward(*tensor_inputs)
         return outputs
 
     @staticmethod
@@ -79,7 +82,11 @@ class CheckpointFunction(torch.autograd.Function):
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
-        tensors = ctx.saved_tensors
+
+        if ctx.activation_offload:
+            tensors = ctx.tensor_inputs
+        else:
+            tensors = ctx.saved_tensors
 
         # store the current states
         bwd_cpu_rng_state = torch.get_rng_state()
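
The gist of the change: when `activation_offload` is enabled, the checkpointed tensors are kept on `ctx` directly (`ctx.tensor_inputs`) instead of going through `ctx.save_for_backward`, and the backward pass reads from whichever store was used. Below is a minimal, self-contained sketch of that pattern; it is not ColossalAI's actual `CheckpointFunction`, and the class name `OffloadAwareCheckpoint` plus the CPU-offload details are illustrative assumptions.

```python
import torch


class OffloadAwareCheckpoint(torch.autograd.Function):
    """Sketch of the pattern in the patch: bypass save_for_backward when
    activations are offloaded, and stash the tensors on ctx instead."""

    @staticmethod
    def forward(ctx, run_function, activation_offload, *args):
        # Assumes all *args are tensors, for brevity.
        ctx.run_function = run_function
        ctx.activation_offload = activation_offload

        with torch.no_grad():
            outputs = run_function(*args)

        if activation_offload:
            # Keep CPU copies on ctx directly; remember the original device.
            ctx.input_device = args[0].device if args else torch.device("cpu")
            ctx.tensor_inputs = [a.detach().cpu() for a in args]
        else:
            ctx.save_for_backward(*args)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Read from whichever store forward() used.
        if ctx.activation_offload:
            tensors = [t.to(ctx.input_device) for t in ctx.tensor_inputs]
        else:
            tensors = ctx.saved_tensors

        # Recompute the forward pass with grad enabled, then backprop through it.
        detached = [t.detach().requires_grad_(True) for t in tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*detached)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)

        # No gradients for run_function and activation_offload themselves.
        return (None, None) + tuple(t.grad for t in detached)


if __name__ == "__main__":
    layer = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    y = OffloadAwareCheckpoint.apply(layer, True, x)
    y.sum().backward()
    print(x.grad.shape)
```

One plausible reading of why the split matters: the offloaded copies are detached CPU tensors rather than the tensors the autograd graph produced, so routing them through `save_for_backward` can trip saved-tensor checks on some torch versions (the subject line points at 1.9); stashing them on `ctx` sidesteps that machinery entirely.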