[Gemini] free and allocate cuda memory by tensor.storage, add grad hook (#2040)

pull/2047/head
Zihao authored 2 years ago, committed by GitHub
parent 1e885329f4
commit 6a9158f1fa

@@ -1,9 +1,11 @@
 import torch.nn
 
-from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
+from colossalai.gemini.tensor_utils import free_storage
 from colossalai.nn.parallel.data_parallel import _cast_float
+from functools import partial
 
 __all__ = ['ParamTracerWrapper']
@@ -13,17 +15,21 @@ class ParamTracerWrapper():
         super().__init__()
         self.module = module
         self.dtype = dtype
-        self.param_op_hook = ParamTracerHook()
+        self.param_op_hook = ParamTracerHook(dtype)
 
         for p in module.parameters():
-            assert isinstance(p, ColoParameter)
             p.data = p.data.to(dtype)
+            if p.requires_grad:
+                p.register_hook(partial(self.grad_handle))
 
         self._cast_buffers_to_cuda_dtype()
 
     def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
 
+    def grad_handle(self, grad):
+        free_storage(grad)
+
     def _pre_forward(self):
         self.param_op_hook.mem_monitor.start()
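The wrapper change above registers a Tensor.register_hook callback on every trainable parameter, so a gradient's storage is released the moment autograd produces it and a memory-tracing run never keeps gradient buffers alive. A minimal, self-contained sketch of the same idea in plain PyTorch (free_storage is redefined locally here to stand in for the colossalai.gemini.tensor_utils helper):

import torch
import torch.nn as nn
from functools import partial

def free_storage(tensor: torch.Tensor) -> None:
    # Drop the underlying buffer but keep the tensor object alive.
    if tensor.storage().size() > 0:
        tensor.storage().resize_(0)

def grad_handle(grad: torch.Tensor):
    # Runs as soon as autograd produces the gradient; the tracer only needs
    # the memory timeline, not the gradient values, so free them right away.
    free_storage(grad)

model = nn.Linear(4, 4)
for p in model.parameters():
    if p.requires_grad:
        p.register_hook(partial(grad_handle))

model(torch.randn(2, 4)).sum().backward()
# The .grad tensors still exist, but their storage was released in the hook.
print([p.grad.storage().size() for p in model.parameters()])  # [0, 0]

Because the hook returns None, autograd still assigns the (now empty) tensor to p.grad, so after backward the parameters' gradients occupy no memory.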

@@ -7,6 +7,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.gemini.tensor_utils import free_storage, alloc_storage
 
 
 class TrainingPhase(Enum):
@@ -16,25 +17,26 @@ class TrainingPhase(Enum):
 class ParamTracerHook(ParamOpHook):
 
-    def __init__(self) -> None:
+    def __init__(self, dtype: torch.dtype = torch.half) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self.mem_monitor = SyncCudaMemoryMonitor()
         self._non_model_data_list = []
         self._model_data_list = []
+        self.dtype = dtype
 
-    def _move_params_to_dev(self, params, dev: str) -> int:
-        assert isinstance(dev, str), f"device should be a str not torch.device"
-        comm_volume = 0
-        for p in params:
-            if p.data.device.type != dev:
-                p.data = p.data.to(dev)
-                comm_volume += p.data.numel() * p.data.element_size()
-            if p.grad is not None:
-                if p.grad.device.type != dev:
-                    p.grad = p.grad.to(dev)
-                    comm_volume += p.grad.numel() * p.grad.element_size()
-        return comm_volume
+    def _free_cuda_params(self, params):
+        for p in params:
+            free_storage(p.data)
+
+    def _allocate_params_on_cuda(self, params):
+        for p in params:
+            cur_dev = p.data.device.type
+            if cur_dev == "cpu":
+                # p.data = p.data.to("cuda")
+                p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+            elif cur_dev == "cuda":
+                alloc_storage(p.data)
 
     def sample_model_data(self, params):
         data_volume = 0
@@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
         cuda_volume = self.mem_monitor.finish()
         if len(self._model_data_list):
             self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
-        self._move_params_to_dev(params, 'cuda')
+        self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
 
     def post_op(self, params):
-        self._move_params_to_dev(params, 'cpu')
+        self._free_cuda_params(params)
 
     def pre_forward(self, params: List[torch.Tensor]) -> None:
         self.pre_op(params)
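With this refactor the hook manages parameter memory purely at the storage level: pre_op materialises the parameters an operator is about to touch on CUDA (a torch.randn placeholder for CPU parameters, since only the allocation size matters when tracing memory, and alloc_storage for CUDA parameters whose buffer was freed earlier), while post_op drops the CUDA storage again instead of copying data back to the CPU. A framework-free sketch of that round trip, assuming the tensor_utils helpers introduced below; this shows only the core idea, not the ParamOpHook interface:

import torch

def free_storage(t: torch.Tensor) -> None:
    if t.storage().size() > 0:
        t.storage().resize_(0)

def alloc_storage(t: torch.Tensor) -> None:
    if t.storage().size() == 0:
        t.storage().resize_(t.numel())

def pre_op(params, dtype=torch.half):
    # Materialise every parameter involved in the upcoming op on CUDA.
    for p in params:
        if p.data.device.type == "cpu":
            # A random placeholder avoids a host-to-device copy; the values
            # are irrelevant when only memory usage is being traced.
            p.data = torch.randn(p.data.shape, device="cuda", dtype=dtype)
        elif p.data.device.type == "cuda":
            # Re-grow the storage that a previous post_op released.
            alloc_storage(p.data)

def post_op(params):
    # Keep the tensor objects, return their CUDA memory to the allocator.
    for p in params:
        free_storage(p.data)

Compared with the old _move_params_to_dev, no real host-to-device or device-to-host copies are issued during tracing; the commented-out p.data = p.data.to("cuda") line in the hunk above marks exactly that switch.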

@@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
 from typing import Union, Tuple
 
+def is_storage_empty(tensor: torch.Tensor) -> bool:
+    return tensor.storage().size() == 0
+
+def free_storage(tensor: torch.Tensor) -> None:
+    if not is_storage_empty(tensor):
+        tensor.storage().resize_(0)
+
+def alloc_storage(tensor: torch.Tensor) -> None:
+    if is_storage_empty(tensor):
+        tensor.storage().resize_(tensor.numel())
+
 def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
     if isinstance(tensor, StatefulTensor):
         t = tensor.payload
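The new helpers rely on the fact that storage().resize_(0) returns a tensor's buffer to the caching allocator while the tensor object itself (shape, dtype, device) stays usable as a handle, and resize_(numel) re-grows the buffer later. A quick way to observe this on a CUDA device (a 1024x1024 fp16 tensor, i.e. about 2 MiB):

import torch
from colossalai.gemini.tensor_utils import alloc_storage, free_storage

if torch.cuda.is_available():
    t = torch.randn(1024, 1024, dtype=torch.half, device="cuda")
    base = torch.cuda.memory_allocated()

    free_storage(t)                        # what post_op does to parameters
    freed = torch.cuda.memory_allocated()

    alloc_storage(t)                       # what pre_op does before the next op
    restored = torch.cuda.memory_allocated()

    print(base - freed, restored - freed)  # both around 2 MiB (1024 * 1024 * 2 bytes)
    print(t.shape, t.storage().size())     # shape survives, storage is back to numel()

memory_allocated() drops by the tensor's size after free_storage and comes back after alloc_storage, which is what lets pre_op/post_op toggle model data on and off the device without destroying the parameter tensors.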

@@ -16,7 +16,7 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
     model.backward(loss)
 
 def run_param_wrapper_testing():
-    test_models = ['simple_net']
+    test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
