|
|
@ -7,6 +7,19 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
|
|
|
|
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
|
|
|
|
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
|
|
|
|
from .cuda import get_current_device
|
|
|
|
from .cuda import get_current_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def copy_to_device(obj, device):
|
|
|
|
|
|
|
|
if torch.is_tensor(obj):
|
|
|
|
|
|
|
|
return obj.to(device)
|
|
|
|
|
|
|
|
elif isinstance(obj, list):
|
|
|
|
|
|
|
|
return [copy_to_device(i, device) for i in obj]
|
|
|
|
|
|
|
|
elif isinstance(obj, tuple):
|
|
|
|
|
|
|
|
return tuple([copy_to_device(v, device) for v in obj])
|
|
|
|
|
|
|
|
elif isinstance(obj, dict):
|
|
|
|
|
|
|
|
return {k: copy_to_device(v, device) for k, v in obj.items()}
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
@ -26,7 +39,14 @@ class CheckpointFunction(torch.autograd.Function):
|
|
|
|
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
|
|
|
|
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ctx.had_autocast_in_fwd = False
|
|
|
|
ctx.had_autocast_in_fwd = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if activation_offload:
|
|
|
|
|
|
|
|
inputs_cuda = copy_to_device(args, ctx.device)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
inputs_cuda = args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
|
|
|
outputs = run_function(*inputs_cuda)
|
|
|
|
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
|
|
|
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
|
|
|
# to be filled out during the backward.
|
|
|
|
# to be filled out during the backward.
|
|
|
|
ctx.inputs = []
|
|
|
|
ctx.inputs = []
|
|
|
@ -34,10 +54,8 @@ class CheckpointFunction(torch.autograd.Function):
|
|
|
|
tensor_inputs = []
|
|
|
|
tensor_inputs = []
|
|
|
|
for i, arg in enumerate(args):
|
|
|
|
for i, arg in enumerate(args):
|
|
|
|
if torch.is_tensor(arg):
|
|
|
|
if torch.is_tensor(arg):
|
|
|
|
if ctx.activation_offload:
|
|
|
|
if activation_offload:
|
|
|
|
tmp = arg.detach().cpu()
|
|
|
|
tensor_inputs.append(copy_to_device(arg, 'cpu'))
|
|
|
|
tmp.requires_grad = arg.requires_grad
|
|
|
|
|
|
|
|
tensor_inputs.append(tmp)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tensor_inputs.append(arg)
|
|
|
|
tensor_inputs.append(arg)
|
|
|
|
ctx.tensor_indices.append(i)
|
|
|
|
ctx.tensor_indices.append(i)
|
|
|
@ -46,18 +64,15 @@ class CheckpointFunction(torch.autograd.Function):
|
|
|
|
ctx.inputs.append(arg)
|
|
|
|
ctx.inputs.append(arg)
|
|
|
|
|
|
|
|
|
|
|
|
ctx.save_for_backward(*tensor_inputs)
|
|
|
|
ctx.save_for_backward(*tensor_inputs)
|
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
|
|
|
outputs = run_function(*args)
|
|
|
|
|
|
|
|
return outputs
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx, *args):
|
|
|
|
def backward(ctx, *args):
|
|
|
|
if not torch.autograd._is_checkpoint_valid():
|
|
|
|
if not torch.autograd._is_checkpoint_valid():
|
|
|
|
raise RuntimeError(
|
|
|
|
raise RuntimeError(
|
|
|
|
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
|
|
|
|
"Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
|
|
|
|
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
|
|
|
|
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
|
|
|
|
" argument.")
|
|
|
|
)
|
|
|
|
# Copy the list to avoid modifying original list.
|
|
|
|
# Copy the list to avoid modifying original list.
|
|
|
|
inputs = list(ctx.inputs)
|
|
|
|
inputs = list(ctx.inputs)
|
|
|
|
tensor_indices = ctx.tensor_indices
|
|
|
|
tensor_indices = ctx.tensor_indices
|
|
|
@ -74,12 +89,12 @@ class CheckpointFunction(torch.autograd.Function):
|
|
|
|
for parallel_mode, state in ctx.fwd_seed_states.items():
|
|
|
|
for parallel_mode, state in ctx.fwd_seed_states.items():
|
|
|
|
set_seed_states(parallel_mode, state)
|
|
|
|
set_seed_states(parallel_mode, state)
|
|
|
|
set_mode(ctx.fwd_current_mode)
|
|
|
|
set_mode(ctx.fwd_current_mode)
|
|
|
|
|
|
|
|
if ctx.activation_offload:
|
|
|
|
|
|
|
|
tensors = copy_to_device(tensors, ctx.device)
|
|
|
|
|
|
|
|
|
|
|
|
# Fill in inputs with appropriate saved tensors.
|
|
|
|
# Fill in inputs with appropriate saved tensors.
|
|
|
|
for i, idx in enumerate(tensor_indices):
|
|
|
|
for i, idx in enumerate(tensor_indices):
|
|
|
|
tmp = tensors[i].detach().to(ctx.device)
|
|
|
|
inputs[idx] = tensors[i]
|
|
|
|
tmp.requires_grad = tensors[i].requires_grad
|
|
|
|
|
|
|
|
inputs[idx] = tmp
|
|
|
|
|
|
|
|
detached_inputs = detach_variable(tuple(inputs))
|
|
|
|
detached_inputs = detach_variable(tuple(inputs))
|
|
|
|
if ctx.had_autocast_in_fwd:
|
|
|
|
if ctx.had_autocast_in_fwd:
|
|
|
|
with torch.enable_grad(), torch.cuda.amp.autocast():
|
|
|
|
with torch.enable_grad(), torch.cuda.amp.autocast():
|
|
|
|