mirror of https://github.com/hpcaitech/ColossalAI
Added activation offload (#331)
* Added activation offload
* Fixed the import bug, used pytest

branch: pull/394/head
parent: 272ebfb57d
commit: de46450461
@@ -13,17 +13,18 @@ from torch import Tensor, nn


 class CheckpointModule(nn.Module):

-    def __init__(self, checkpoint: bool = True):
+    def __init__(self, checkpoint: bool = True, offload: bool = False):
         super().__init__()
         self.checkpoint = checkpoint
         self._use_checkpoint = checkpoint
+        self._offload = offload

     def _forward(self, *args, **kwargs):
         raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')

     def forward(self, *args, **kwargs):
         if self._use_checkpoint:
-            return checkpoint(self._forward, *args, **kwargs)
+            return checkpoint(self._forward, self._offload, *args, **kwargs)
         else:
             return self._forward(*args, **kwargs)
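The new offload flag is simply threaded through to the checkpoint call, so a module written in the CheckpointModule style can opt in per instance. Below is a minimal usage sketch (requires ColossalAI and a CUDA device); MLPBlock, the layer sizes, and the input tensor are illustrative assumptions, and only the checkpoint(fn, offload, *args) calling convention comes from this commit:

import torch
import torch.nn as nn

from colossalai.utils import checkpoint  # functional API whose signature changes in this commit


class MLPBlock(nn.Module):
    """Illustrative block following the CheckpointModule pattern from the hunk above."""

    def __init__(self, dim: int, use_checkpoint: bool = True, offload: bool = False):
        super().__init__()
        self._use_checkpoint = use_checkpoint
        self._offload = offload          # offload checkpointed activations to CPU
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def _forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

    def forward(self, x):
        if self._use_checkpoint:
            # the offload flag goes right after the function, before the inputs
            return checkpoint(self._forward, self._offload, x)
        return self._forward(x)


block = MLPBlock(256, use_checkpoint=True, offload=True).cuda()
x = torch.randn(8, 256, device='cuda', requires_grad=True)
block(x).sum().backward()

With offload=False the previous behaviour is reproduced: the saved activations stay on the GPU between forward and backward.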
@@ -1,3 +1,4 @@
+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint

 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
@@ -5,11 +6,11 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
                      multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                      switch_virtual_pipeline_parallel_rank, sync_model_param)
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
+#from .tensor_detector import TensorDetector

 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@@ -17,5 +18,5 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
+    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter', 'TensorDetector'
 ]
@@ -5,14 +5,16 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable

 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
+from .cuda import get_current_device


 class CheckpointFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, activation_offload=False, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.activation_offload = activation_offload
+        ctx.device = get_current_device()

         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -32,7 +34,12 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                tensor_inputs.append(arg)
+                if ctx.activation_offload:
+                    tmp = arg.detach().cpu()
+                    tmp.requires_grad = arg.requires_grad
+                    tensor_inputs.append(tmp)
+                else:
+                    tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
             else:
@@ -70,8 +77,9 @@ class CheckpointFunction(torch.autograd.Function):

         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            inputs[idx] = tensors[i]
+            tmp = tensors[i].detach().to(ctx.device)
+            tmp.requires_grad = tensors[i].requires_grad
+            inputs[idx] = tmp
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
@@ -82,7 +90,6 @@ class CheckpointFunction(torch.autograd.Function):

         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
-
         # recover the rng states
         torch.set_rng_state(bwd_cpu_rng_state)
         for parallel_mode, state in bwd_seed_states.items():
@@ -103,15 +110,14 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                       for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, activation_offload, *args):
     """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint

     :param function: Describe the forward pass function. It should know how to handle the input tuples.
     :param args: Tuple containing the parameters of the function
     :return: Output of running function with provided args
     """
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, activation_offload, *args)
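The mechanism: when activation_offload is set, the tensors saved for backward are detached to CPU during the forward pass and copied back to the current device right before recomputation in backward; the extra positional argument is also why backward now returns (None, None) + grads. A stripped-down, self-contained sketch of the same idea in plain PyTorch is shown below; it is not the ColossalAI implementation (no RNG-state bookkeeping, a single tensor input, illustrative names):

import torch


class OffloadCheckpoint(torch.autograd.Function):
    """Recompute-on-backward with the saved input parked in host memory."""

    @staticmethod
    def forward(ctx, run_function, x):
        ctx.run_function = run_function
        ctx.device = x.device
        saved = x.detach().cpu()              # offload the activation to CPU
        saved.requires_grad = x.requires_grad
        ctx.save_for_backward(saved)
        with torch.no_grad():
            return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        saved, = ctx.saved_tensors
        x = saved.detach().to(ctx.device)     # bring it back before recomputation
        x.requires_grad = saved.requires_grad
        with torch.enable_grad():
            out = ctx.run_function(x)
        torch.autograd.backward(out, grad_output)
        # one None for the non-tensor run_function argument, then the input grad
        return None, x.grad


def run(x):
    return torch.relu(x) * 2


device = 'cuda' if torch.cuda.is_available() else 'cpu'
inp = torch.randn(4, 4, device=device, requires_grad=True)
OffloadCheckpoint.apply(run, inp).sum().backward()

Note that the functional API is now checkpoint(function, activation_offload, *args), so every existing call site has to pass the flag explicitly, as the updated test below does.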
@@ -17,13 +17,14 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_


 @pytest.mark.gpu
-def test_activation_checkpointing():
-    add_seed(ParallelMode.GLOBAL, 1024)
+@pytest.mark.parametrize("cpu_offload", [True, False])
+def test_activation_checkpointing(cpu_offload):
+    if cpu_offload:
+        add_seed(ParallelMode.GLOBAL, 1024)
+        add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
-    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.DATA)
     data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.GLOBAL)
@@ -49,13 +50,10 @@
     set_mode(ParallelMode.DATA)
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)
-    out = checkpoint(forward, data_, weight_)
+    out = checkpoint(forward, cpu_offload, data_, weight_)
     loss = out.sum()
     loss.backward()

     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()

-
-if __name__ == '__main__':
-    test_activation_checkpointing()