diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index 88e6f1735..808d8149b 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/utils/activation_checkpoint.py @@ -7,6 +7,19 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states from .cuda import get_current_device + +def copy_to_device(obj, device): + if torch.is_tensor(obj): + return obj.to(device) + elif isinstance(obj, list): + return [copy_to_device(i, device) for i in obj] + elif isinstance(obj, tuple): + return tuple([copy_to_device(v, device) for v in obj]) + elif isinstance(obj, dict): + return {k: copy_to_device(v, device) for k, v in obj.items()} + else: + return obj + class CheckpointFunction(torch.autograd.Function): @staticmethod @@ -26,7 +39,14 @@ class CheckpointFunction(torch.autograd.Function): ctx.had_autocast_in_fwd = torch.is_autocast_enabled() else: ctx.had_autocast_in_fwd = False + + if activation_offload: + inputs_cuda = copy_to_device(args, ctx.device) + else: + inputs_cuda = args + with torch.no_grad(): + outputs = run_function(*inputs_cuda) # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. ctx.inputs = [] @@ -34,10 +54,8 @@ class CheckpointFunction(torch.autograd.Function): tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): - if ctx.activation_offload: - tmp = arg.detach().cpu() - tmp.requires_grad = arg.requires_grad - tensor_inputs.append(tmp) + if activation_offload: + tensor_inputs.append(copy_to_device(arg, 'cpu')) else: tensor_inputs.append(arg) ctx.tensor_indices.append(i) @@ -46,18 +64,15 @@ class CheckpointFunction(torch.autograd.Function): ctx.inputs.append(arg) ctx.save_for_backward(*tensor_inputs) - - with torch.no_grad(): - outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument.") + "Checkpointing is not compatible with .grad() or when an `inputs` parameter is " + "passed to .backward(). Please use .backward() and do not pass its `inputs` argument." + ) # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices @@ -74,12 +89,12 @@ class CheckpointFunction(torch.autograd.Function): for parallel_mode, state in ctx.fwd_seed_states.items(): set_seed_states(parallel_mode, state) set_mode(ctx.fwd_current_mode) + if ctx.activation_offload: + tensors = copy_to_device(tensors, ctx.device) # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): - tmp = tensors[i].detach().to(ctx.device) - tmp.requires_grad = tensors[i].requires_grad - inputs[idx] = tmp + inputs[idx] = tensors[i] detached_inputs = detach_variable(tuple(inputs)) if ctx.had_autocast_in_fwd: with torch.enable_grad(), torch.cuda.amp.autocast():