#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable

from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states


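# CheckpointFunction mirrors torch.utils.checkpoint.CheckpointFunction with one
# extension: besides the CPU RNG state, it also snapshots and restores the
# per-parallel-mode seed states managed by colossalai.context.random. Stochastic
# layers such as dropout therefore draw the same random numbers during the
# recomputation in backward() as they did in the original forward().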
class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, *args):
        check_backward_validity(args)
        ctx.run_function = run_function

        # preserve rng states
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        ctx.fwd_seed_states = get_states(copy=True)
        ctx.fwd_current_mode = get_current_mode()

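        # torch.is_autocast_enabled does not exist on older PyTorch releases
        # that predate native AMP, hence the hasattr guard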
        if hasattr(torch, 'is_autocast_enabled'):
            ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        else:
            ctx.had_autocast_in_fwd = False

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

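    # backward() restores the RNG states captured in forward(), recomputes the
    # forward pass under torch.enable_grad() (re-entering autocast if it was
    # active), restores the backward-time RNG states, and only then
    # backpropagates through the recomputed graph.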
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # store the current RNG states
        bwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        bwd_seed_states = get_states(copy=True)
        bwd_current_mode = get_current_mode()

        # restore the states to what they were during the forward pass
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        for parallel_mode, state in ctx.fwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(ctx.fwd_current_mode)

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

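        # detach the filled-in inputs so that the recomputed graph starts at
        # the checkpoint boundary; gradients accumulate on these detached
        # leaves and are collected below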
        detached_inputs = detach_variable(tuple(inputs))
        if ctx.had_autocast_in_fwd:
            with torch.enable_grad(), torch.cuda.amp.autocast():
                outputs = ctx.run_function(*detached_inputs)
        else:
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # recover the rng states
        torch.set_rng_state(bwd_cpu_rng_state)
        for parallel_mode, state in bwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(bwd_current_mode)

        # run backward() only on the outputs that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        # the leading None is the gradient for forward()'s run_function argument
        return (None,) + grads


def checkpoint(function, *args):
    """Checkpoint the computation while preserving the RNG states, modified from PyTorch's ``torch.utils.checkpoint``.

    :param function: The forward pass function. It should know how to handle the input tuples.
    :param args: Tuple containing the inputs to ``function``
    :return: Output of running ``function`` with the provided args
    """
    return CheckpointFunction.apply(function, *args)
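

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). It assumes
# colossalai has been initialized (e.g. via colossalai.launch) so that the
# seed manager in colossalai.context.random is ready; `run_block` and the
# tensor shapes below are made up for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    linear = torch.nn.Linear(4, 4)

    def run_block(x):
        # dropout makes the RNG preservation observable: the recomputed
        # forward pass must draw the same mask as the original one
        return torch.nn.functional.dropout(linear(x), p=0.5, training=True)

    inp = torch.randn(2, 4, requires_grad=True)
    out = checkpoint(run_block, inp)    # runs under torch.no_grad()
    out.sum().backward()                # re-runs forward, then backpropagates
    print(inp.grad.shape)               # torch.Size([2, 4])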