diff --git a/colossalai/gemini/memory_tracer/model_data_memtracer.py b/colossalai/gemini/memory_tracer/model_data_memtracer.py index 98228892d..c228bdff4 100644 --- a/colossalai/gemini/memory_tracer/model_data_memtracer.py +++ b/colossalai/gemini/memory_tracer/model_data_memtracer.py @@ -106,4 +106,15 @@ class ModelDataTracer(metaclass=SingletonMeta): return self._get_mem_usage() +class CudaMemInfo(metaclass=SingletonMeta): + + def __init__(self) -> None: + self.model_data_list = [] + self.non_model_data_list = [] + self.unreleased_grad_flag = {} + self.unreleased_grad_volume = 0 + + GLOBAL_MODEL_DATA_TRACER = ModelDataTracer() + +GLOBAL_CUDA_MEM_INFO = CudaMemInfo() \ No newline at end of file diff --git a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py index 50cc1451e..f69df73e3 100644 --- a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py +++ b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py @@ -1,11 +1,9 @@ import torch.nn from colossalai.tensor.param_op_hook import ParamOpHookManager -from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook -from colossalai.gemini.tensor_utils import free_storage +from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook +from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO from colossalai.nn.parallel.data_parallel import _cast_float -from functools import partial - __all__ = ['ParamTracerWrapper'] @@ -15,22 +13,33 @@ class ParamTracerWrapper(): super().__init__() self.module = module self.dtype = dtype - self.param_op_hook = ParamTracerHook(dtype) + self.param_op_hook = ParamTracerHook() + self.grad_hook = GradHook(module) + self.cpu_param_data_dict = {} for p in module.parameters(): p.data = p.data.to(dtype) - if p.requires_grad: - p.register_hook(partial(self.grad_handle)) self._cast_buffers_to_cuda_dtype() def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def grad_handle(self, grad): - free_storage(grad) + def _save_param_data_on_cpu(self): + for p in self.module.parameters(): + self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu") + self.cpu_param_data_dict[p].copy_(p.data) + + def _restore_param_data(self): + for p in self.module.parameters(): + p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad) + p.data.copy_(self.cpu_param_data_dict[p]) + self.cpu_param_data_dict.clear() def _pre_forward(self): + self._clear_cuda_mem_info() + self._save_param_data_on_cpu() + self.grad_hook.register_grad_hook() self.param_op_hook.mem_monitor.start() def forward(self, *args, **kwargs): @@ -48,8 +57,16 @@ class ParamTracerWrapper(): def _post_backward(self): cuda_volume = self.param_op_hook.mem_monitor.finish() - last_model_data = self.param_op_hook._model_data_list[-1] - self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data) + last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1] + GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data) + self.grad_hook.remove_grad_hook() + self._restore_param_data() + + def _clear_cuda_mem_info(self): + GLOBAL_CUDA_MEM_INFO.model_data_list.clear() + GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear() + GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear() + GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0 def _cast_buffers_to_cuda_dtype(self): for buffer in self.module.buffers(): diff --git a/colossalai/gemini/ophooks/param_trace_hook.py b/colossalai/gemini/ophooks/param_trace_hook.py index aef2cdbd7..678927d78 100644 --- a/colossalai/gemini/ophooks/param_trace_hook.py +++ b/colossalai/gemini/ophooks/param_trace_hook.py @@ -8,6 +8,7 @@ import torch from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor from colossalai.tensor.param_op_hook import ParamOpHook from colossalai.gemini.tensor_utils import free_storage, alloc_storage +from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO class TrainingPhase(Enum): @@ -15,42 +16,69 @@ class TrainingPhase(Enum): BACKWARD = 1 +class GradHook(): + def __init__(self, module: torch.nn.Module): + self.module = module + self.grad_hook_list = [] + + def grad_handle(self, p, grad): + assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] + free_storage(grad) + GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size() + GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False + + def register_grad_hook(self): + for p in self.module.parameters(): + if p.requires_grad: + self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p))) + GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False + + def remove_grad_hook(self): + for hook in self.grad_hook_list: + hook.remove() + + class ParamTracerHook(ParamOpHook): - def __init__(self, dtype: torch.dtype = torch.half) -> None: + def __init__(self) -> None: super().__init__() self._training_phase = TrainingPhase.FORWARD self.mem_monitor = SyncCudaMemoryMonitor() - self._non_model_data_list = [] - self._model_data_list = [] - self.dtype = dtype def _free_cuda_params(self, params): for p in params: + if p.data.device.type == "cpu": + raise NotImplementedError("Only free cuda memory") free_storage(p.data) def _allocate_params_on_cuda(self, params): for p in params: cur_dev = p.data.device.type if cur_dev == "cpu": - # p.data = p.data.to("cuda") - p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype) + if p.grad is not None and p.grad.device.type == "cpu": + raise NotImplementedError("Only run in forward propagation") + p.data = torch.empty(p.data.shape, device="cuda", dtype=p.data.dtype, + requires_grad=p.data.requires_grad) elif cur_dev == "cuda": alloc_storage(p.data) def sample_model_data(self, params): - data_volume = 0 + data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume for p in params: - data_volume += p.data.numel() * p.data.element_size() - if self._training_phase == TrainingPhase.BACKWARD: - # add param.grad, actually param.grad is None in this time - data_volume *= 2 - self._model_data_list.append(data_volume) + cur_model_data_volume = p.data.numel() * p.data.element_size() + data_volume += cur_model_data_volume + if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad: + # add param.grad, actually param.grad is None in this time + data_volume += cur_model_data_volume + if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]: + GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume + GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True + GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume) def pre_op(self, params): cuda_volume = self.mem_monitor.finish() - if len(self._model_data_list): - self._non_model_data_list.append(cuda_volume - self._model_data_list[-1]) + if len(GLOBAL_CUDA_MEM_INFO.model_data_list): + GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1]) self._allocate_params_on_cuda(params) self.sample_model_data(params) self.mem_monitor.start() diff --git a/tests/test_gemini/test_param_tracer.py b/tests/test_gemini/test_param_tracer.py index d82778271..7e4c6dff5 100644 --- a/tests/test_gemini/test_param_tracer.py +++ b/tests/test_gemini/test_param_tracer.py @@ -2,6 +2,7 @@ import numpy as np import torch from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper +from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -35,9 +36,9 @@ def run_param_wrapper_testing(): run_fwd_bwd(model, data, label, criterion, False) - cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2 + cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024 ** 2 print("cuda_non_model_data_list", len(cuda_non_model_data_list)) - # print(model.param_op_hook._non_model_data_list) + # print(GLOBAL_CUDA_MEM_INFO.non_model_data_list) del model