mirror of https://github.com/hpcaitech/ColossalAI
[utils] Add use_reetrant=False in utils.activation_checkpoint (#1460)
* [utils] Add use_reetrant=False into colossalai checkpoint * [utils] add some annotation in utils.activaion_checkpoint * [test] add reset_seed at the beginning of tests in test_actiavion_checkpointing.py * [test] modify test_activation_checkpoint.py * [test] modify test for reentrant=Falsepull/1463/head
parent
36824a304c
commit
47fd8e4a02
|
@ -7,6 +7,8 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
|
|||
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
|
||||
from .cuda import get_current_device
|
||||
|
||||
import weakref
|
||||
|
||||
|
||||
def copy_to_device(obj, device):
|
||||
if torch.is_tensor(obj):
|
||||
|
@ -136,14 +138,122 @@ class CheckpointFunction(torch.autograd.Function):
|
|||
return (None, None) + grads
|
||||
|
||||
|
||||
def checkpoint(function, activation_offload, *args):
|
||||
def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
|
||||
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
|
||||
|
||||
Args:
|
||||
function: Describe the forward pass function. It should know how to handle the input tuples.
|
||||
activation_offload: The variable to check whether we should offload activation to cpu
|
||||
args (list): Tuple containing the parameters of the function
|
||||
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
|
||||
might be more flexibility for user to define there checkpoint function
|
||||
|
||||
Returns:
|
||||
Output of running function with provided args.
|
||||
"""
|
||||
return CheckpointFunction.apply(function, activation_offload, *args)
|
||||
if use_reentrant:
|
||||
return CheckpointFunction.apply(function, activation_offload, *args)
|
||||
else:
|
||||
return _checkpoint_without_reentrant(
|
||||
function,
|
||||
activation_offload,
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
||||
# store rng_state
|
||||
fwd_cpu_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
fwd_seed_states = get_states(copy=True)
|
||||
fwd_current_mode = get_current_mode()
|
||||
|
||||
# check if use autocast
|
||||
if hasattr(torch, 'is_autocast_enabled'):
|
||||
has_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
has_autocast_in_fwd = False
|
||||
|
||||
# using WeakKeyDictionary to store all the activation the first time we call unpack
|
||||
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
||||
weak_holder_list = []
|
||||
|
||||
# class for weakref.ref
|
||||
class Holder():
|
||||
pass
|
||||
|
||||
# return a Holder object for later unpack process
|
||||
def pack(x):
|
||||
res = Holder()
|
||||
weak_holder_list.append(weakref.ref(res))
|
||||
return res
|
||||
|
||||
# unpack hook
|
||||
def unpack(x):
|
||||
unpack_counter = 0
|
||||
|
||||
# re-compute all the activation inside the function when we first call unpack
|
||||
if len(storage) == 0:
|
||||
|
||||
def inner_pack(inner):
|
||||
nonlocal unpack_counter
|
||||
unpack_counter += 1
|
||||
|
||||
# If the holder went out of scope, the SavedVariable is dead and so
|
||||
# the value will never be read from the storage. Skip filling it.
|
||||
if weak_holder_list[unpack_counter - 1]() is None:
|
||||
return
|
||||
|
||||
# Use detach here to ensure we don't keep the temporary autograd
|
||||
# graph created during the second forward
|
||||
storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
|
||||
return
|
||||
|
||||
def inner_unpack(packed):
|
||||
raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
|
||||
|
||||
# restore rng state
|
||||
torch.set_rng_state(fwd_cpu_state)
|
||||
for parallel_mode, state in fwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(fwd_current_mode)
|
||||
|
||||
# reload arg into device if needed
|
||||
if activation_offload:
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg):
|
||||
arg = arg.to(device=device)
|
||||
|
||||
# rerun forward, the inner_pack will store all the activations in storage
|
||||
if has_autocast_in_fwd:
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(), \
|
||||
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
_unused = function(*args)
|
||||
else:
|
||||
with torch.enable_grad(), \
|
||||
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
_unused = function(*args)
|
||||
|
||||
if x not in storage:
|
||||
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
|
||||
" recomputation being triggered in between, this is not currently supported. Please"
|
||||
" open an issue with details on your use case so that we can prioritize adding this.")
|
||||
|
||||
return storage[x]
|
||||
|
||||
# get device if we need to offload the activation
|
||||
if activation_offload:
|
||||
device = get_current_device()
|
||||
|
||||
# run function with pack and unpack as saved_tensors_hooks
|
||||
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
|
||||
output = function(*args)
|
||||
|
||||
# offload activation if needed
|
||||
if activation_offload:
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg):
|
||||
arg = arg.to(device="cpu")
|
||||
|
||||
return output
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils.activation_checkpoint import checkpoint
|
||||
|
||||
|
||||
def forward(x, weight):
|
||||
|
@ -16,10 +16,37 @@ def forward(x, weight):
|
|||
return out_
|
||||
|
||||
|
||||
def forward_inplace_ckpt(x, weight, cpu_offload=False):
|
||||
out = torch.matmul(x, weight)
|
||||
bn = torch.nn.BatchNorm1d(4, affine=False)
|
||||
bn = bn.to(device="cuda")
|
||||
out = bn(out)
|
||||
|
||||
def ckpt0(x):
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
|
||||
return out
|
||||
|
||||
|
||||
def forward_inplace(x, weight):
|
||||
out = torch.matmul(x, weight)
|
||||
bn = torch.nn.BatchNorm1d(4, affine=False)
|
||||
bn = bn.to(device="cuda")
|
||||
out = bn(out)
|
||||
out = F.relu(out, inplace=True)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.skip("set seed error")
|
||||
@pytest.mark.parametrize("use_reentrant", [True, False])
|
||||
@pytest.mark.parametrize("cpu_offload", [True, False])
|
||||
def test_activation_checkpointing(cpu_offload):
|
||||
def test_activation_checkpointing(cpu_offload, use_reentrant):
|
||||
|
||||
# as seed manager is singleton
|
||||
# if we don't reset seeds here,
|
||||
# other tests might affect this test
|
||||
reset_seeds()
|
||||
|
||||
# We put initilization here to avoid change cuda rng state below
|
||||
inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
|
||||
|
@ -50,15 +77,46 @@ def test_activation_checkpointing(cpu_offload):
|
|||
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
||||
set_mode(ParallelMode.GLOBAL)
|
||||
|
||||
out = checkpoint(forward, cpu_offload, inputs_, weight_)
|
||||
out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Extra test for use_reentrant=False
|
||||
if use_reentrant == False:
|
||||
# Recover cuda rng states
|
||||
set_mode(ParallelMode.GLOBAL)
|
||||
torch.cuda.set_rng_state(global_cuda_rng_state)
|
||||
set_mode(ParallelMode.DATA)
|
||||
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
||||
set_mode(ParallelMode.GLOBAL)
|
||||
|
||||
out = forward_inplace(inputs, weight)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
# Recover cuda rng states
|
||||
set_mode(ParallelMode.GLOBAL)
|
||||
torch.cuda.set_rng_state(global_cuda_rng_state)
|
||||
set_mode(ParallelMode.DATA)
|
||||
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
||||
set_mode(ParallelMode.GLOBAL)
|
||||
|
||||
out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# as seed manager is singleton
|
||||
# if we don't reset seeds here,
|
||||
# other tests will fail if running together with this test
|
||||
# as other tests can't overwrite the seed set by this test
|
||||
reset_seeds()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_activation_checkpointing(False, False)
|
||||
|
|
Loading…
Reference in New Issue