mirror of https://github.com/hpcaitech/ColossalAI
[utils] Add use_reetrant=False in utils.activation_checkpoint (#1460)
* [utils] Add use_reetrant=False into colossalai checkpoint * [utils] add some annotation in utils.activaion_checkpoint * [test] add reset_seed at the beginning of tests in test_actiavion_checkpointing.py * [test] modify test_activation_checkpoint.py * [test] modify test for reentrant=Falsepull/1463/head
parent
36824a304c
commit
47fd8e4a02
|
@ -7,6 +7,8 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
|
||||||
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
|
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
|
||||||
from .cuda import get_current_device
|
from .cuda import get_current_device
|
||||||
|
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
|
||||||
def copy_to_device(obj, device):
|
def copy_to_device(obj, device):
|
||||||
if torch.is_tensor(obj):
|
if torch.is_tensor(obj):
|
||||||
|
@ -136,14 +138,122 @@ class CheckpointFunction(torch.autograd.Function):
|
||||||
return (None, None) + grads
|
return (None, None) + grads
|
||||||
|
|
||||||
|
|
||||||
def checkpoint(function, activation_offload, *args):
|
def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
|
||||||
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
|
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function: Describe the forward pass function. It should know how to handle the input tuples.
|
function: Describe the forward pass function. It should know how to handle the input tuples.
|
||||||
|
activation_offload: The variable to check whether we should offload activation to cpu
|
||||||
args (list): Tuple containing the parameters of the function
|
args (list): Tuple containing the parameters of the function
|
||||||
|
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
|
||||||
|
might be more flexibility for user to define there checkpoint function
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output of running function with provided args.
|
Output of running function with provided args.
|
||||||
"""
|
"""
|
||||||
return CheckpointFunction.apply(function, activation_offload, *args)
|
if use_reentrant:
|
||||||
|
return CheckpointFunction.apply(function, activation_offload, *args)
|
||||||
|
else:
|
||||||
|
return _checkpoint_without_reentrant(
|
||||||
|
function,
|
||||||
|
activation_offload,
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
||||||
|
# store rng_state
|
||||||
|
fwd_cpu_state = torch.get_rng_state()
|
||||||
|
sync_states()
|
||||||
|
fwd_seed_states = get_states(copy=True)
|
||||||
|
fwd_current_mode = get_current_mode()
|
||||||
|
|
||||||
|
# check if use autocast
|
||||||
|
if hasattr(torch, 'is_autocast_enabled'):
|
||||||
|
has_autocast_in_fwd = torch.is_autocast_enabled()
|
||||||
|
else:
|
||||||
|
has_autocast_in_fwd = False
|
||||||
|
|
||||||
|
# using WeakKeyDictionary to store all the activation the first time we call unpack
|
||||||
|
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
||||||
|
weak_holder_list = []
|
||||||
|
|
||||||
|
# class for weakref.ref
|
||||||
|
class Holder():
|
||||||
|
pass
|
||||||
|
|
||||||
|
# return a Holder object for later unpack process
|
||||||
|
def pack(x):
|
||||||
|
res = Holder()
|
||||||
|
weak_holder_list.append(weakref.ref(res))
|
||||||
|
return res
|
||||||
|
|
||||||
|
# unpack hook
|
||||||
|
def unpack(x):
|
||||||
|
unpack_counter = 0
|
||||||
|
|
||||||
|
# re-compute all the activation inside the function when we first call unpack
|
||||||
|
if len(storage) == 0:
|
||||||
|
|
||||||
|
def inner_pack(inner):
|
||||||
|
nonlocal unpack_counter
|
||||||
|
unpack_counter += 1
|
||||||
|
|
||||||
|
# If the holder went out of scope, the SavedVariable is dead and so
|
||||||
|
# the value will never be read from the storage. Skip filling it.
|
||||||
|
if weak_holder_list[unpack_counter - 1]() is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use detach here to ensure we don't keep the temporary autograd
|
||||||
|
# graph created during the second forward
|
||||||
|
storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
|
||||||
|
return
|
||||||
|
|
||||||
|
def inner_unpack(packed):
|
||||||
|
raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
|
||||||
|
|
||||||
|
# restore rng state
|
||||||
|
torch.set_rng_state(fwd_cpu_state)
|
||||||
|
for parallel_mode, state in fwd_seed_states.items():
|
||||||
|
set_seed_states(parallel_mode, state)
|
||||||
|
set_mode(fwd_current_mode)
|
||||||
|
|
||||||
|
# reload arg into device if needed
|
||||||
|
if activation_offload:
|
||||||
|
for arg in args:
|
||||||
|
if torch.is_tensor(arg):
|
||||||
|
arg = arg.to(device=device)
|
||||||
|
|
||||||
|
# rerun forward, the inner_pack will store all the activations in storage
|
||||||
|
if has_autocast_in_fwd:
|
||||||
|
with torch.enable_grad(), \
|
||||||
|
torch.cuda.amp.autocast(), \
|
||||||
|
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||||
|
_unused = function(*args)
|
||||||
|
else:
|
||||||
|
with torch.enable_grad(), \
|
||||||
|
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||||
|
_unused = function(*args)
|
||||||
|
|
||||||
|
if x not in storage:
|
||||||
|
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
|
||||||
|
" recomputation being triggered in between, this is not currently supported. Please"
|
||||||
|
" open an issue with details on your use case so that we can prioritize adding this.")
|
||||||
|
|
||||||
|
return storage[x]
|
||||||
|
|
||||||
|
# get device if we need to offload the activation
|
||||||
|
if activation_offload:
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
# run function with pack and unpack as saved_tensors_hooks
|
||||||
|
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
|
||||||
|
output = function(*args)
|
||||||
|
|
||||||
|
# offload activation if needed
|
||||||
|
if activation_offload:
|
||||||
|
for arg in args:
|
||||||
|
if torch.is_tensor(arg):
|
||||||
|
arg = arg.to(device="cpu")
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
|
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
|
||||||
from colossalai.utils import checkpoint
|
from colossalai.utils.activation_checkpoint import checkpoint
|
||||||
|
|
||||||
|
|
||||||
def forward(x, weight):
|
def forward(x, weight):
|
||||||
|
@ -16,10 +16,37 @@ def forward(x, weight):
|
||||||
return out_
|
return out_
|
||||||
|
|
||||||
|
|
||||||
|
def forward_inplace_ckpt(x, weight, cpu_offload=False):
|
||||||
|
out = torch.matmul(x, weight)
|
||||||
|
bn = torch.nn.BatchNorm1d(4, affine=False)
|
||||||
|
bn = bn.to(device="cuda")
|
||||||
|
out = bn(out)
|
||||||
|
|
||||||
|
def ckpt0(x):
|
||||||
|
return F.relu(x, inplace=True)
|
||||||
|
|
||||||
|
out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_inplace(x, weight):
|
||||||
|
out = torch.matmul(x, weight)
|
||||||
|
bn = torch.nn.BatchNorm1d(4, affine=False)
|
||||||
|
bn = bn.to(device="cuda")
|
||||||
|
out = bn(out)
|
||||||
|
out = F.relu(out, inplace=True)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.gpu
|
@pytest.mark.gpu
|
||||||
@pytest.mark.skip("set seed error")
|
@pytest.mark.parametrize("use_reentrant", [True, False])
|
||||||
@pytest.mark.parametrize("cpu_offload", [True, False])
|
@pytest.mark.parametrize("cpu_offload", [True, False])
|
||||||
def test_activation_checkpointing(cpu_offload):
|
def test_activation_checkpointing(cpu_offload, use_reentrant):
|
||||||
|
|
||||||
|
# as seed manager is singleton
|
||||||
|
# if we don't reset seeds here,
|
||||||
|
# other tests might affect this test
|
||||||
|
reset_seeds()
|
||||||
|
|
||||||
# We put initilization here to avoid change cuda rng state below
|
# We put initilization here to avoid change cuda rng state below
|
||||||
inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
|
inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
|
||||||
|
@ -50,15 +77,46 @@ def test_activation_checkpointing(cpu_offload):
|
||||||
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
||||||
set_mode(ParallelMode.GLOBAL)
|
set_mode(ParallelMode.GLOBAL)
|
||||||
|
|
||||||
out = checkpoint(forward, cpu_offload, inputs_, weight_)
|
out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
|
||||||
loss = out.sum()
|
loss = out.sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
|
assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Extra test for use_reentrant=False
|
||||||
|
if use_reentrant == False:
|
||||||
|
# Recover cuda rng states
|
||||||
|
set_mode(ParallelMode.GLOBAL)
|
||||||
|
torch.cuda.set_rng_state(global_cuda_rng_state)
|
||||||
|
set_mode(ParallelMode.DATA)
|
||||||
|
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
||||||
|
set_mode(ParallelMode.GLOBAL)
|
||||||
|
|
||||||
|
out = forward_inplace(inputs, weight)
|
||||||
|
loss = out.sum()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Recover cuda rng states
|
||||||
|
set_mode(ParallelMode.GLOBAL)
|
||||||
|
torch.cuda.set_rng_state(global_cuda_rng_state)
|
||||||
|
set_mode(ParallelMode.DATA)
|
||||||
|
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
|
||||||
|
set_mode(ParallelMode.GLOBAL)
|
||||||
|
|
||||||
|
out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
|
||||||
|
loss = out.sum()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# as seed manager is singleton
|
# as seed manager is singleton
|
||||||
# if we don't reset seeds here,
|
# if we don't reset seeds here,
|
||||||
# other tests will fail if running together with this test
|
# other tests will fail if running together with this test
|
||||||
# as other tests can't overwrite the seed set by this test
|
# as other tests can't overwrite the seed set by this test
|
||||||
reset_seeds()
|
reset_seeds()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_activation_checkpointing(False, False)
|
||||||
|
|
Loading…
Reference in New Issue