mirror of https://github.com/hpcaitech/ColossalAI
Added activation offload (#331)
* Added activation offload
* Fixed the import bug, used the pytest

Branch: pull/394/head
parent 272ebfb57d
commit de46450461
@@ -13,17 +13,18 @@ from torch import Tensor, nn


 class CheckpointModule(nn.Module):

-    def __init__(self, checkpoint: bool = True):
+    def __init__(self, checkpoint: bool = True, offload : bool = False):
         super().__init__()
         self.checkpoint = checkpoint
         self._use_checkpoint = checkpoint
+        self._offload = offload

     def _forward(self, *args, **kwargs):
         raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')

     def forward(self, *args, **kwargs):
         if self._use_checkpoint:
-            return checkpoint(self._forward, *args, **kwargs)
+            return checkpoint(self._forward, self._offload, *args, **kwargs)
         else:
             return self._forward(*args, **kwargs)
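The hunk above lets any layer built on `CheckpointModule` request offloading through its constructor and have the flag forwarded to `checkpoint()` automatically. Below is a minimal sketch of such a layer under this revision; the class name, layer sizes, and the exact import path of `CheckpointModule` are assumptions for illustration, not part of the commit:

```python
import torch
from torch import nn

# import path is an assumption about where CheckpointModule lives at this revision
from colossalai.nn.layer.utils import CheckpointModule


class MLPBlock(CheckpointModule):  # hypothetical user-defined block
    def __init__(self, dim: int, checkpoint: bool = True, offload: bool = False):
        # offload=True asks the wrapped checkpoint() call to park the saved
        # inputs in host memory between the forward and backward passes
        super().__init__(checkpoint=checkpoint, offload=offload)
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # subclasses implement _forward; the base class routes it through
        # checkpoint(self._forward, self._offload, ...) when checkpointing is on
        return self.fc2(torch.relu(self.fc1(x)))
```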
@@ -1,3 +1,4 @@
+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
@@ -5,11 +6,11 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
                      multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                      switch_virtual_pipeline_parallel_rank, sync_model_param)
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
+#from .tensor_detector import TensorDetector

 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@@ -17,5 +18,5 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
+    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter', 'TensorDetector'
 ]
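With `checkpoint` still re-exported from `colossalai.utils`, existing call sites only need to thread the new positional flag through. A usage sketch, assuming a CUDA device and that ColossalAI's parallel context and seed manager have already been initialized (for example via `colossalai.launch`); the run function and tensor shapes are illustrative:

```python
import torch
from colossalai.utils import checkpoint


def run_fn(x, w):
    # the forward computation to be recomputed during backward
    return torch.matmul(x, w)


x = torch.randn(4, 8, device='cuda', requires_grad=True)
w = torch.randn(8, 2, device='cuda', requires_grad=True)

# the second positional argument is the new activation_offload flag
out = checkpoint(run_fn, True, x, w)
out.sum().backward()
```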
@@ -5,14 +5,16 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable

 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
 from .cuda import get_current_device

 class CheckpointFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, activation_offload=False, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.activation_offload = activation_offload
+        ctx.device = get_current_device()

         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -32,7 +34,12 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                tensor_inputs.append(arg)
+                if ctx.activation_offload:
+                    tmp = arg.detach().cpu()
+                    tmp.requires_grad = arg.requires_grad
+                    tensor_inputs.append(tmp)
+                else:
+                    tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
             else:
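The forward-side branch above boils down to saving a detached host copy that still remembers its `requires_grad` flag, rather than keeping a reference to the device tensor. A standalone PyTorch sketch of just that transformation (variable names are illustrative):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
arg = torch.randn(2, 2, device=device, requires_grad=True)

# without offload: hold a reference to the device tensor until backward
kept_on_device = arg

# with offload: hold a detached CPU copy instead; requires_grad is carried over
# so backward can later rebuild an equivalent leaf on the compute device
tmp = arg.detach().cpu()
tmp.requires_grad = arg.requires_grad
```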
@@ -70,8 +77,9 @@ class CheckpointFunction(torch.autograd.Function):

         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            inputs[idx] = tensors[i]
+            tmp = tensors[i].detach().to(ctx.device)
+            tmp.requires_grad = tensors[i].requires_grad
+            inputs[idx] = tmp
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
@@ -82,7 +90,6 @@ class CheckpointFunction(torch.autograd.Function):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)

-
         # recover the rng states
         torch.set_rng_state(bwd_cpu_rng_state)
         for parallel_mode, state in bwd_seed_states.items():
@@ -103,15 +110,14 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                       for inp in detached_inputs)

-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, activation_offload ,*args):
     """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint

     :param function: Describe the forward pass function. It should know how to handle the input tuples.
     :param args: Tuple containing the parameters of the function
     :return: Output of running function with provided args
     """
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, activation_offload, *args)
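Taken together, the forward pass parks the tensors saved for backward in host memory and the backward pass moves them back to `ctx.device` before recomputation. The toy `torch.autograd.Function` below mirrors only that round trip; it is not the library's `CheckpointFunction` (no recomputation, no RNG bookkeeping), and `OffloadedMatmul` is an invented name used purely for illustration:

```python
import torch


class OffloadedMatmul(torch.autograd.Function):
    """Toy example: offload saved tensors to CPU in forward, restore in backward."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.device = x.device
        # park the tensors needed for the gradient in host memory
        ctx.save_for_backward(x.detach().cpu(), w.detach().cpu())
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x_cpu, w_cpu = ctx.saved_tensors
        # bring the saved tensors back to the original device for the backward math
        x, w = x_cpu.to(ctx.device), w_cpu.to(ctx.device)
        return grad_out @ w.t(), x.t() @ grad_out


if torch.cuda.is_available():
    x = torch.randn(4, 8, device='cuda', requires_grad=True)
    w = torch.randn(8, 2, device='cuda', requires_grad=True)
    OffloadedMatmul.apply(x, w).sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([8, 2])
```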
@@ -17,13 +17,14 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_


-@pytest.mark.gpu
-def test_activation_checkpointing():
-    add_seed(ParallelMode.GLOBAL, 1024)
+@pytest.mark.parametrize("cpu_offload", [True, False])
+def test_activation_checkpointing(cpu_offload):
+    if cpu_offload:
+        add_seed(ParallelMode.GLOBAL, 1024)
+        add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
-    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.DATA)
     data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.GLOBAL)
@@ -49,13 +50,10 @@ def test_activation_checkpointing():
     set_mode(ParallelMode.DATA)
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)
-    out = checkpoint(forward, data_, weight_)
+    out = checkpoint(forward, cpu_offload, data_, weight_)
     loss = out.sum()
     loss.backward()

     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()

-
-if __name__ == '__main__':
-    test_activation_checkpointing()
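With the `if __name__ == '__main__'` runner removed, the test is meant to be collected by pytest, which runs it once for each `cpu_offload` value on a CUDA machine. A programmatic invocation sketch; the test file path is an assumption about the repository layout:

```python
import pytest

# path assumed; adjust to where the test file actually lives in the repo
pytest.main(["-q", "tests/test_utils/test_activation_checkpointing.py"])
```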