@@ -5,14 +5,16 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 
 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
 from .cuda import get_current_device
 
 
 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, activation_offload=False, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.activation_offload = activation_offload
+        ctx.device = get_current_device()
 
         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -32,7 +34,12 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                tensor_inputs.append(arg)
+                if ctx.activation_offload:
+                    tmp = arg.detach().cpu()
+                    tmp.requires_grad = arg.requires_grad
+                    tensor_inputs.append(tmp)
+                else:
+                    tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
             else:
@@ -70,8 +77,9 @@ class CheckpointFunction(torch.autograd.Function):
 
         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            inputs[idx] = tensors[i]
-
+            tmp = tensors[i].detach().to(ctx.device)
+            tmp.requires_grad = tensors[i].requires_grad
+            inputs[idx] = tmp
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
@@ -82,7 +90,6 @@ class CheckpointFunction(torch.autograd.Function):
 
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
-
         # recover the rng states
         torch.set_rng_state(bwd_cpu_rng_state)
         for parallel_mode, state in bwd_seed_states.items():
@@ -103,15 +110,14 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                       for inp in detached_inputs)
-
-        return (None,) + grads
+        return (None, None) + grads
 
 
-def checkpoint(function, *args):
+def checkpoint(function, activation_offload, *args):
     """Checkpoint the computation while preserving the RNG states, modified from PyTorch torch.utils.checkpoint.
 
     :param function: Describe the forward pass function. It should know how to handle the input tuples.
     :param args: Tuple containing the parameters of the function
     :return: Output of running function with provided args
     """
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, activation_offload, *args)
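
For reference, below is a minimal usage sketch of the patched API, not part of the diff itself. It assumes the function is importable as colossalai.utils.activation_checkpoint.checkpoint, that a CUDA device is available, and that any ColossalAI parallel/seed context the RNG bookkeeping relies on has been initialized; the MyBlock module, its sizes, and the input shape are illustrative only. With the new signature, activation_offload is passed positionally before the checkpointed function's own inputs: when it is True, the saved input tensors are kept in CPU memory after the forward pass and copied back to the current device before recomputation in the backward pass.

import torch
import torch.nn as nn

# Assumed import path for the patched function; adjust to the actual package layout.
from colossalai.utils.activation_checkpoint import checkpoint


class MyBlock(nn.Module):
    """Illustrative sub-module whose activations we want to checkpoint."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


block = MyBlock().cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

# activation_offload=True stores the checkpointed inputs on CPU between the
# forward and backward passes; False matches the previous behaviour.
y = checkpoint(block, True, x)
y.sum().backward()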