Browse Source

Added activation offload (#331)

* Added activation offload

* Fixed the import bug, used the pytest
pull/394/head
LuGY 3 years ago committed by Frank Lee
parent
commit
de46450461
  1. 5
      colossalai/nn/layer/utils/common.py
  2. 5
      colossalai/utils/__init__.py
  3. 26
      colossalai/utils/activation_checkpoint.py
  4. 14
      tests/test_utils/test_activation_checkpointing.py

5
colossalai/nn/layer/utils/common.py

@ -13,17 +13,18 @@ from torch import Tensor, nn
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True):
def __init__(self, checkpoint: bool = True, offload : bool = False):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
self._offload = offload
def _forward(self, *args, **kwargs):
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args, **kwargs)
return checkpoint(self._forward, self._offload, *args, **kwargs)
else:
return self._forward(*args, **kwargs)

5
colossalai/utils/__init__.py

@ -1,3 +1,4 @@
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
@ -5,11 +6,11 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
switch_virtual_pipeline_parallel_rank, sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory import report_memory_usage
from .timer import MultiTimer, Timer
#from .tensor_detector import TensorDetector
__all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@ -17,5 +18,5 @@ __all__ = [
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter', 'TensorDetector'
]

26
colossalai/utils/activation_checkpoint.py

@ -5,14 +5,16 @@ import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
from .cuda import get_current_device
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, *args):
def forward(ctx, run_function, activation_offload=False, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.activation_offload = activation_offload
ctx.device = get_current_device()
# preserve rng states
ctx.fwd_cpu_rng_state = torch.get_rng_state()
@ -32,7 +34,12 @@ class CheckpointFunction(torch.autograd.Function):
tensor_inputs = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
tensor_inputs.append(arg)
if ctx.activation_offload:
tmp = arg.detach().cpu()
tmp.requires_grad = arg.requires_grad
tensor_inputs.append(tmp)
else:
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
@ -70,8 +77,9 @@ class CheckpointFunction(torch.autograd.Function):
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
tmp = tensors[i].detach().to(ctx.device)
tmp.requires_grad = tensors[i].requires_grad
inputs[idx] = tmp
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
with torch.enable_grad(), torch.cuda.amp.autocast():
@ -82,7 +90,6 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# recover the rng states
torch.set_rng_state(bwd_cpu_rng_state)
for parallel_mode, state in bwd_seed_states.items():
@ -103,15 +110,14 @@ class CheckpointFunction(torch.autograd.Function):
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)
return (None,) + grads
return (None, None) + grads
def checkpoint(function, *args):
def checkpoint(function, activation_offload ,*args):
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint
:param function: Describe the forward pass function. It should know how to handle the input tuples.
:param args: Tuple containing the parameters of the function
:return: Output of running function with provided args
"""
return CheckpointFunction.apply(function, *args)
return CheckpointFunction.apply(function, activation_offload, *args)

14
tests/test_utils/test_activation_checkpointing.py

@ -17,13 +17,14 @@ def forward(x, weight):
out_ = F.dropout(out, p=0.4, training=True)
return out_
@pytest.mark.gpu
def test_activation_checkpointing():
add_seed(ParallelMode.GLOBAL, 1024)
@pytest.mark.parametrize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload):
if cpu_offload:
add_seed(ParallelMode.GLOBAL, 1024)
add_seed(ParallelMode.DATA, 1026)
set_mode(ParallelMode.GLOBAL)
global_cuda_rng_state = torch.cuda.get_rng_state()
add_seed(ParallelMode.DATA, 1026)
set_mode(ParallelMode.DATA)
data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
set_mode(ParallelMode.GLOBAL)
@ -49,13 +50,10 @@ def test_activation_checkpointing():
set_mode(ParallelMode.DATA)
torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
set_mode(ParallelMode.GLOBAL)
out = checkpoint(forward, data_, weight_)
out = checkpoint(forward, cpu_offload, data_, weight_)
loss = out.sum()
loss.backward()
assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
torch.cuda.empty_cache()
if __name__ == '__main__':
test_activation_checkpointing()

Loading…
Cancel
Save