diff --git a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
index b6b26fe9a..50cc1451e 100644
--- a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
+++ b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
@@ -1,9 +1,11 @@
 import torch.nn
 
-from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
+from colossalai.gemini.tensor_utils import free_storage
 from colossalai.nn.parallel.data_parallel import _cast_float
+from functools import partial
+
 
 __all__ = ['ParamTracerWrapper']
 
@@ -13,17 +15,21 @@ class ParamTracerWrapper():
         super().__init__()
         self.module = module
         self.dtype = dtype
-        self.param_op_hook = ParamTracerHook()
+        self.param_op_hook = ParamTracerHook(dtype)
 
         for p in module.parameters():
-            assert isinstance(p, ColoParameter)
             p.data = p.data.to(dtype)
+            if p.requires_grad:
+                p.register_hook(partial(self.grad_handle))
 
         self._cast_buffers_to_cuda_dtype()
 
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
+    def grad_handle(self, grad):
+        free_storage(grad)
+
     def _pre_forward(self):
         self.param_op_hook.mem_monitor.start()
 
diff --git a/colossalai/gemini/ophooks/param_trace_hook.py b/colossalai/gemini/ophooks/param_trace_hook.py
index a8fd5df52..aef2cdbd7 100644
--- a/colossalai/gemini/ophooks/param_trace_hook.py
+++ b/colossalai/gemini/ophooks/param_trace_hook.py
@@ -7,6 +7,7 @@ import torch
 
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.gemini.tensor_utils import free_storage, alloc_storage
 
 
 class TrainingPhase(Enum):
@@ -16,25 +17,26 @@ class TrainingPhase(Enum):
 
 class ParamTracerHook(ParamOpHook):
 
-    def __init__(self) -> None:
+    def __init__(self, dtype: torch.dtype = torch.half) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self.mem_monitor = SyncCudaMemoryMonitor()
         self._non_model_data_list = []
         self._model_data_list = []
+        self.dtype = dtype
 
-    def _move_params_to_dev(self, params, dev: str) -> int:
-        assert isinstance(dev, str), f"device should be a str not torch.device"
-        comm_volume = 0
+    def _free_cuda_params(self, params):
         for p in params:
-            if p.data.device.type != dev:
-                p.data = p.data.to(dev)
-                comm_volume += p.data.numel() * p.data.element_size()
-            if p.grad is not None:
-                if p.grad.device.type != dev:
-                    p.grad = p.grad.to(dev)
-                    comm_volume += p.grad.numel() * p.grad.element_size()
-        return comm_volume
+            free_storage(p.data)
+
+    def _allocate_params_on_cuda(self, params):
+        for p in params:
+            cur_dev = p.data.device.type
+            if cur_dev == "cpu":
+                # p.data = p.data.to("cuda")
+                p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+            elif cur_dev == "cuda":
+                alloc_storage(p.data)
 
     def sample_model_data(self, params):
         data_volume = 0
@@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
         cuda_volume = self.mem_monitor.finish()
         if len(self._model_data_list):
             self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
-        self._move_params_to_dev(params, 'cuda')
+        self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
 
     def post_op(self, params):
-        self._move_params_to_dev(params, 'cpu')
+        self._free_cuda_params(params)
 
     def pre_forward(self, params: List[torch.Tensor]) -> None:
         self.pre_op(params)
 
diff --git a/colossalai/gemini/tensor_utils.py b/colossalai/gemini/tensor_utils.py
index f2d69046e..bcc159f99 100644
--- a/colossalai/gemini/tensor_utils.py
+++ b/colossalai/gemini/tensor_utils.py
@@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
 from typing import Union, Tuple
 
 
+def is_storage_empty(tensor: torch.Tensor) -> bool:
+    return tensor.storage().size() == 0
+
+
+def free_storage(tensor: torch.Tensor) -> None:
+    if not is_storage_empty(tensor):
+        tensor.storage().resize_(0)
+
+
+def alloc_storage(tensor: torch.Tensor) -> None:
+    if is_storage_empty(tensor):
+        tensor.storage().resize_(tensor.numel())
+
+
 def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
     if isinstance(tensor, StatefulTensor):
         t = tensor.payload
diff --git a/tests/test_gemini/test_param_tracer.py b/tests/test_gemini/test_param_tracer.py
index 79f311cb5..d82778271 100644
--- a/tests/test_gemini/test_param_tracer.py
+++ b/tests/test_gemini/test_param_tracer.py
@@ -16,7 +16,7 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
     model.backward(loss)
 
 def run_param_wrapper_testing():
-    test_models = ['simple_net']
+    test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
 
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
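
Note (not part of the patch): `free_storage` / `alloc_storage` work by resizing a tensor's underlying storage in place, so the tracer can make parameters "disappear" from and "reappear" on CUDA without real data movement or host copies; `_allocate_params_on_cuda` fills re-created parameters with `torch.randn`, presumably because the tracer only needs the allocation footprint, not correct values. Below is a minimal standalone sketch of how the new helpers behave, assuming this patch is applied and a CUDA-capable PyTorch build; the shape and printed numbers are illustrative only.

```python
import torch

from colossalai.gemini.tensor_utils import alloc_storage, free_storage, is_storage_empty

p = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
print(torch.cuda.memory_allocated())     # ~2 MiB held by p

free_storage(p)                           # storage shrinks to 0 bytes; shape metadata survives
assert is_storage_empty(p)
print(p.shape)                            # still torch.Size([1024, 1024])

alloc_storage(p)                          # buffer re-acquired; contents are uninitialized
assert not is_storage_empty(p)
print(torch.cuda.memory_allocated())      # roughly back to the original figure
```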