diff --git a/colossalai/gemini/memory_tracer/model_data_memtracer.py b/colossalai/gemini/memory_tracer/model_data_memtracer.py
index c228bdff4..3274486fd 100644
--- a/colossalai/gemini/memory_tracer/model_data_memtracer.py
+++ b/colossalai/gemini/memory_tracer/model_data_memtracer.py
@@ -1,6 +1,8 @@
-from colossalai.context.singleton_meta import SingletonMeta
+from typing import Optional, Tuple
+
 import torch
-from typing import Tuple, Optional
+
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import DistributedLogger
 
 
@@ -20,7 +22,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
 
 
 def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
-    """ 
+    """
     Trace the model memory usage.
     Args:
         model (torch.nn.Module): a torch model
@@ -106,15 +108,4 @@ class ModelDataTracer(metaclass=SingletonMeta):
         return self._get_mem_usage()
 
 
-class CudaMemInfo(metaclass=SingletonMeta):
-
-    def __init__(self) -> None:
-        self.model_data_list = []
-        self.non_model_data_list = []
-        self.unreleased_grad_flag = {}
-        self.unreleased_grad_volume = 0
-
-
 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
-
-GLOBAL_CUDA_MEM_INFO = CudaMemInfo()
\ No newline at end of file
diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
index 275a88335..dc204e352 100644
--- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,8 +1,7 @@
 import torch.nn
 
 from colossalai.gemini.memory_tracer import MemStats
-from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
-from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
+from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 
@@ -25,9 +24,10 @@ class RuntimeMemTracer():
         super().__init__()
         self.module = module
         self.dtype = dtype
+        self._gradstat = GradMemStats()
         self._memstats = MemStats()
-        self.param_op_hook = ParamMemTracerHook(self._memstats)
-        self.grad_hook = GradMemTracerHook(module)
+        self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
+        self.grad_hook = GradMemTracerHook(self._gradstat)
         self.cpu_param_data_dict = {}
 
         for p in module.parameters():
@@ -58,7 +58,7 @@ class RuntimeMemTracer():
     def _pre_forward(self):
         self._clear_cuda_mem_info()
         self._backup_params()
-        self.grad_hook.register_grad_hook()
+        self.grad_hook.register_grad_hook(self.module)
         self.param_op_hook.mem_monitor.start()
 
     def forward(self, *args, **kwargs):
@@ -78,17 +78,12 @@ class RuntimeMemTracer():
         cuda_volume = self.param_op_hook.mem_monitor.finish()
         self._memstats.append_model_data('cuda', cuda_volume)
         self._memstats.append_non_model_data('cuda')
-        # last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
-        # GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
         self.grad_hook.remove_grad_hook()
         self._restore_params()
 
     def _clear_cuda_mem_info(self):
-        # GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
-        # GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
         self._memstats.clear()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
+        self._gradstat.clear()
 
     def _cast_buffers_to_cuda_dtype(self):
         for buffer in self.module.buffers():
diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
index 55362f888..465c13747 100644
--- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -6,7 +6,6 @@ from typing import List
 import torch
 
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 
@@ -16,23 +15,34 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
+class GradMemStats():
+
+    def __init__(self) -> None:
+        self.unreleased_grad_flag = {}
+        self.unreleased_grad_volume = 0
+
+    def clear(self):
+        self.unreleased_grad_flag.clear()
+        self.unreleased_grad_volume = 0
+
+
 class GradMemTracerHook():
 
-    def __init__(self, module: torch.nn.Module):
-        self.module = module
+    def __init__(self, grad_stats: GradMemStats):
         self.grad_hook_list = []
+        self._grad_stats = grad_stats
 
     def grad_handle(self, p, grad):
-        assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]
+        assert self._grad_stats.unreleased_grad_flag[p]
         free_storage(grad)
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+        self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
+        self._grad_stats.unreleased_grad_flag[p] = False
 
-    def register_grad_hook(self):
-        for p in self.module.parameters():
+    def register_grad_hook(self, module: torch.nn.Module):
+        for p in module.parameters():
             if p.requires_grad:
                 self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
-                GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+                self._grad_stats.unreleased_grad_flag[p] = False
 
     def remove_grad_hook(self):
         for hook in self.grad_hook_list:
@@ -41,10 +51,11 @@ class GradMemTracerHook():
 
 class ParamMemTracerHook(ColoParamOpHook):
 
-    def __init__(self, memstats) -> None:
+    def __init__(self, memstats, gradstats: GradMemStats) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self._memstats = memstats
+        self._grad_stats = gradstats
         self.mem_monitor = SyncCudaMemoryMonitor()
 
     def _free_cuda_params(self, params):
@@ -67,24 +78,21 @@ class ParamMemTracerHook(ColoParamOpHook):
                 alloc_storage(p.data)
 
     def sample_model_data(self, params):
-        data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume
+        data_volume = self._grad_stats.unreleased_grad_volume
         for p in params:
             cur_model_data_volume = p.data.numel() * p.data.element_size()
             data_volume += cur_model_data_volume
             if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
                 # add param.grad, actually param.grad is None in this time
                 data_volume += cur_model_data_volume
-                if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
-                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
-                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
-        # GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
+                if not self._grad_stats.unreleased_grad_flag[p]:
+                    self._grad_stats.unreleased_grad_volume += cur_model_data_volume
+                    self._grad_stats.unreleased_grad_flag[p] = True
         self._memstats.append_model_data('cuda', data_volume)
 
     def pre_op(self, params):
         cuda_volume = self.mem_monitor.finish()
         self._memstats.append_model_data('cuda', cuda_volume)
-        # if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
-        #     GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
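
For reference, a minimal sketch of how the refactored pieces fit together after this change. It uses only the classes and methods that appear in the diff (GradMemStats, GradMemTracerHook, ParamMemTracerHook, MemStats); the model variable and the elided training loop are placeholders, not part of this patch.

    # Hedged sketch: a per-tracer GradMemStats instance replaces the removed
    # GLOBAL_CUDA_MEM_INFO singleton and is shared by both hooks.
    grad_stats = GradMemStats()            # holds unreleased_grad_flag / unreleased_grad_volume
    mem_stats = MemStats()
    param_op_hook = ParamMemTracerHook(mem_stats, grad_stats)
    grad_hook = GradMemTracerHook(grad_stats)

    grad_hook.register_grad_hook(model)    # module is now passed at registration time, not in __init__
    # ... run forward/backward under param_op_hook, as RuntimeMemTracer does ...
    grad_hook.remove_grad_hook()
    grad_stats.clear()                     # resets both the flag dict and the byte counter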