@@ -7,6 +7,19 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
 from .cuda import get_current_device
 
+
+def copy_to_device(obj, device):
+    if torch.is_tensor(obj):
+        return obj.to(device)
+    elif isinstance(obj, list):
+        return [copy_to_device(i, device) for i in obj]
+    elif isinstance(obj, tuple):
+        return tuple([copy_to_device(v, device) for v in obj])
+    elif isinstance(obj, dict):
+        return {k: copy_to_device(v, device) for k, v in obj.items()}
+    else:
+        return obj
+
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
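Note on the helper added above: copy_to_device walks nested lists, tuples, and dicts, moves every tensor it finds with Tensor.to(device), and returns non-tensor leaves unchanged. A minimal illustration (sketch only, not part of the patch; the example values are invented):

    device = get_current_device()
    batch = {'ids': torch.arange(4), 'meta': ('train', 0)}
    moved = copy_to_device(batch, device)   # tensor 'ids' is moved; the non-tensor 'meta' tuple is rebuilt as-is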
@@ -27,6 +40,13 @@ class CheckpointFunction(torch.autograd.Function):
         else:
             ctx.had_autocast_in_fwd = False
 
+        if activation_offload:
+            inputs_cuda = copy_to_device(args, ctx.device)
+        else:
+            inputs_cuda = args
+
+        with torch.no_grad():
+            outputs = run_function(*inputs_cuda)
         # Save non-tensor inputs in ctx, keep a placeholder None for tensors
         # to be filled out during the backward.
         ctx.inputs = []
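With the hunk above, an offloaded forward pass first copies the incoming args to the compute device (inputs_cuda) and then runs run_function under torch.no_grad(); the next hunk stashes tensor arguments on the CPU so the backward pass can reload them. A usage sketch, assuming the module also exposes a checkpoint wrapper that forwards the flag to CheckpointFunction.apply (that wrapper is not shown in this excerpt):

    # sketch only: wrapper name and argument order are assumed, not taken from this diff
    out = checkpoint(block, True, hidden_states)   # True = offload saved activations to CPU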
@@ -34,10 +54,8 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                if ctx.activation_offload:
-                    tmp = arg.detach().cpu()
-                    tmp.requires_grad = arg.requires_grad
-                    tensor_inputs.append(tmp)
+                if activation_offload:
+                    tensor_inputs.append(copy_to_device(arg, 'cpu'))
                 else:
                     tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
@@ -46,18 +64,15 @@ class CheckpointFunction(torch.autograd.Function):
                 ctx.inputs.append(arg)
 
         ctx.save_for_backward(*tensor_inputs)
-
-        with torch.no_grad():
-            outputs = run_function(*args)
         return outputs
 
     @staticmethod
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
-                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
-                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
-                " argument.")
+                "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
+                "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
+            )
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
@@ -74,12 +89,12 @@ class CheckpointFunction(torch.autograd.Function):
         for parallel_mode, state in ctx.fwd_seed_states.items():
             set_seed_states(parallel_mode, state)
         set_mode(ctx.fwd_current_mode)
+        if ctx.activation_offload:
+            tensors = copy_to_device(tensors, ctx.device)
 
         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            tmp = tensors[i].detach().to(ctx.device)
-            tmp.requires_grad = tensors[i].requires_grad
-            inputs[idx] = tmp
+            inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
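In the backward hunk above, the per-tensor detach / .to(ctx.device) / requires_grad bookkeeping is replaced by one copy_to_device call plus a direct assignment; detach_variable, imported from torch.utils.checkpoint at the top of the file, then detaches each restored input while preserving its requires_grad flag before recomputation. Roughly, its behavior is (paraphrase for reference, not part of the patch):

    def detach_variable_sketch(inputs):
        out = []
        for inp in inputs:
            if not torch.is_tensor(inp):
                out.append(inp)
                continue
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)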