From de4645046121d6cb89c91566097d6580f5750a03 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Fri, 11 Mar 2022 10:08:10 +0800
Subject: [PATCH] Added activation offload (#331)

* Added activation offload

* Fixed the import bug, used the pytest
---
 colossalai/nn/layer/utils/common.py       |  5 ++--
 colossalai/utils/__init__.py              |  5 ++--
 colossalai/utils/activation_checkpoint.py | 26 ++++++++++++-------
 .../test_activation_checkpointing.py      | 14 +++++-----
 4 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py
index 2ec7d19dd..8aa5473d8 100644
--- a/colossalai/nn/layer/utils/common.py
+++ b/colossalai/nn/layer/utils/common.py
@@ -13,17 +13,18 @@ from torch import Tensor, nn
 
 class CheckpointModule(nn.Module):
 
-    def __init__(self, checkpoint: bool = True):
+    def __init__(self, checkpoint: bool = True, offload : bool = False):
         super().__init__()
         self.checkpoint = checkpoint
         self._use_checkpoint = checkpoint
+        self._offload = offload
 
     def _forward(self, *args, **kwargs):
         raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
 
     def forward(self, *args, **kwargs):
         if self._use_checkpoint:
-            return checkpoint(self._forward, *args, **kwargs)
+            return checkpoint(self._forward, self._offload, *args, **kwargs)
         else:
             return self._forward(*args, **kwargs)
 
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 848937d7c..b8536a1d5 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,3 +1,4 @@
+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
@@ -5,11 +6,11 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
                      multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                      switch_virtual_pipeline_parallel_rank, sync_model_param)
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
+#from .tensor_detector import TensorDetector
 
 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@@ -17,5 +18,5 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
+    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter', 'TensorDetector'
 ]
diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py
index f50211614..fdac97977 100644
--- a/colossalai/utils/activation_checkpoint.py
+++ b/colossalai/utils/activation_checkpoint.py
@@ -5,14 +5,16 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 from colossalai.context.random import get_states, get_current_mode, \
     set_seed_states, set_mode, sync_states
-
+from .cuda import get_current_device
 
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, activation_offload=False, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.activation_offload = activation_offload
+        ctx.device = get_current_device()
 
         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -32,7 +34,12 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                tensor_inputs.append(arg)
+                if ctx.activation_offload:
+                    tmp = arg.detach().cpu()
+                    tmp.requires_grad = arg.requires_grad
+                    tensor_inputs.append(tmp)
+                else:
+                    tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
             else:
@@ -70,8 +77,9 @@ class CheckpointFunction(torch.autograd.Function):
 
         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            inputs[idx] = tensors[i]
-
+            tmp = tensors[i].detach().to(ctx.device)
+            tmp.requires_grad = tensors[i].requires_grad
+            inputs[idx] = tmp
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
@@ -82,7 +90,6 @@ class CheckpointFunction(torch.autograd.Function):
 
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
-
         # recover the rng states
         torch.set_rng_state(bwd_cpu_rng_state)
         for parallel_mode, state in bwd_seed_states.items():
@@ -103,15 +110,14 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs_with_grad, args_with_grad)
 
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
-
-        return (None,) + grads
+        return (None, None) + grads
 
 
-def checkpoint(function, *args):
+def checkpoint(function, activation_offload ,*args):
     """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint
 
     :param function: Describe the forward pass function. It should know how to handle the input tuples.
     :param args: Tuple containing the parameters of the function
     :return: Output of running function with provided args
     """
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, activation_offload, *args)
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index 1adc548fb..f4127228d 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -17,13 +17,14 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_
 
-
 @pytest.mark.gpu
-def test_activation_checkpointing():
-    add_seed(ParallelMode.GLOBAL, 1024)
+@pytest.mark.parametrize("cpu_offload", [True, False])
+def test_activation_checkpointing(cpu_offload):
+    if cpu_offload:
+        add_seed(ParallelMode.GLOBAL, 1024)
+        add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
-    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.DATA)
     data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.GLOBAL)
@@ -49,13 +50,10 @@ def test_activation_checkpointing():
     set_mode(ParallelMode.DATA)
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)
-    out = checkpoint(forward, data_, weight_)
+    out = checkpoint(forward, cpu_offload, data_, weight_)
     loss = out.sum()
     loss.backward()
 
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
 
-
-if __name__ == '__main__':
-    test_activation_checkpointing()
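Note (editor's sketch, not part of the patch): the core of the CheckpointFunction change is that, when activation_offload is enabled, the tensors saved for recomputation are detached to CPU memory during the forward pass and copied back to the working device (get_current_device()) only when backward() recomputes the segment, so checkpointed inputs do not sit in GPU memory between the two passes. The standalone PyTorch sketch below illustrates that pattern under stated assumptions: the names OffloadCheckpoint and offload_checkpoint are hypothetical and are not ColossalAI APIs, the RNG-state and autocast bookkeeping of the real CheckpointFunction is omitted, and plain attribute storage is used instead of the ctx.save_for_backward machinery of torch.utils.checkpoint.

import torch


class OffloadCheckpoint(torch.autograd.Function):
    """Checkpoint a segment and keep its saved inputs on the CPU (illustration only)."""

    @staticmethod
    def forward(ctx, run_function, activation_offload, *args):
        ctx.run_function = run_function
        # remember where the inputs originally lived so backward can restore them
        ctx.device = next((a.device for a in args if torch.is_tensor(a)), torch.device('cpu'))
        saved = []
        for arg in args:
            if torch.is_tensor(arg) and activation_offload:
                cpu_copy = arg.detach().cpu()            # offload the saved input to host memory
                cpu_copy.requires_grad = arg.requires_grad
                saved.append(cpu_copy)
            else:
                saved.append(arg)
        ctx.saved = saved                                # simplified; no save_for_backward here
        with torch.no_grad():                            # no autograd graph is kept for the segment
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = []
        for item in ctx.saved:
            if torch.is_tensor(item):
                restored = item.detach().to(ctx.device)  # bring the saved input back to the device
                restored.requires_grad = item.requires_grad
                inputs.append(restored)
            else:
                inputs.append(item)
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)          # recompute the checkpointed segment
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)   # assumes the outputs require grad
        grads = tuple(inp.grad if torch.is_tensor(inp) else None for inp in inputs)
        # one None per non-tensor forward argument (run_function, activation_offload)
        return (None, None) + grads


def offload_checkpoint(run_function, activation_offload, *args):
    return OffloadCheckpoint.apply(run_function, activation_offload, *args)


if __name__ == '__main__' and torch.cuda.is_available():
    x = torch.randn(4, 32, device='cuda', requires_grad=True)
    w = torch.randn(32, 32, device='cuda', requires_grad=True)
    out = offload_checkpoint(lambda a, b: torch.relu(a @ b), True, x, w)
    out.sum().backward()                                 # x.grad and w.grad are populated here

As the patch itself shows, the offload flag is positional and precedes the checkpointed inputs, so existing call sites change from checkpoint(fn, *args) to checkpoint(fn, activation_offload, *args); CheckpointModule passes its new offload constructor argument through the same way, and backward now returns (None, None) + grads to account for the extra non-tensor argument.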