diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
index 829e0d4d4..ead95535e 100644
--- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,7 +1,7 @@
 import torch.nn
 
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
-from colossalai.gemini.ophooks.param_trace_hook import GradHook, ParamTracerHook
+from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 
@@ -14,8 +14,8 @@ class RuntimeMemTracer():
         super().__init__()
         self.module = module
         self.dtype = dtype
-        self.param_op_hook = ParamTracerHook()
-        self.grad_hook = GradHook(module)
+        self.param_op_hook = ParamMemTracerHook()
+        self.grad_hook = GradMemTracerHook(module)
         self.cpu_param_data_dict = {}
 
         for p in module.parameters():
diff --git a/colossalai/gemini/ophooks/_shard_grad_ophook.py b/colossalai/gemini/ophooks/_shard_grad_ophook.py
index 582f95802..5115ff74d 100644
--- a/colossalai/gemini/ophooks/_shard_grad_ophook.py
+++ b/colossalai/gemini/ophooks/_shard_grad_ophook.py
@@ -1,11 +1,12 @@
 import torch
+
 from colossalai.registry import OPHOOKS
 
 from . import BaseOpHook
 
 
 @OPHOOKS.register_module
-class ShardGradHook(BaseOpHook):
+class ShardGradMemTracerHook(BaseOpHook):
     """
     A hook to process sharded param before and afther FWD and BWD operator executing.
     """
diff --git a/colossalai/gemini/ophooks/mem_trace_hook.py b/colossalai/gemini/ophooks/mem_trace_hook.py
deleted file mode 100644
index 697655259..000000000
--- a/colossalai/gemini/ophooks/mem_trace_hook.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import torch
-
-from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.gemini.ophooks import BaseOpHook
-
-
-class MemTracerOpHook(BaseOpHook):
-    """
-    TODO() what if parameters are sharded by multiple submodules.
-    register buff on its father node
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.mem_monitor = SyncCudaMemoryMonitor()
-        self._cur_non_model_data_vol = 0
-        self._non_model_data_list = []
-        self._cur_model_data_vol = 0
-
-    def _move_module_to_dev(self, module, dev: str) -> int:
-        """
-        move module to target dev
-        Args:
-            module (torch.nn.Module): a PyTorch module
-            dev (torch.device): the target device
-        Returns:
-            int: the data volume of this module on the cuda
-        """
-        assert isinstance(dev, str), f"device should be a str not torch.device"
-        comm_volume = 0
-        for p in module.parameters():
-            if p.data.device.type != dev:
-                p.data = p.data.to(dev)
-                comm_volume += p.data.numel() * p.data.element_size()
-            if p.grad is not None:
-                if p.grad.device.type != dev:
-                    p.grad = p.grad.to(dev)
-                    comm_volume += p.grad.numel() * p.grad.element_size()
-
-        for buf in module.buffers():
-            if buf.device.type != dev:
-                buf.data = buf.data.to(dev)
-                comm_volume += buf.data.numel() * buf.data.element_size()
-
-        if dev == 'cuda':
-            self._cur_model_data_vol = comm_volume
-
-        return comm_volume
-
-    def pre_fwd_exec(self, module: torch.nn.Module, *args):
-        if module.training:
-            cuda_volume = self.mem_monitor.finish()
-            comm_volume = self._move_module_to_dev(module, 'cuda')
-            self.mem_monitor.start()
-            # print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')
-
-    def post_fwd_exec(self, module: torch.nn.Module, *args):
-        if module.training:
-            cuda_volume = self.mem_monitor.finish()
-            comm_volume = self._move_module_to_dev(module, 'cpu')
-            self._non_model_data_list.append(cuda_volume - comm_volume)
-            # print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
-
-    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
-        assert isinstance(module, torch.nn.Module)
-        if module.training:
-            cuda_volume = self.mem_monitor.finish()
-            self._move_module_to_dev(module, 'cuda')
-            self.mem_monitor.start()
-            # print(f'BWD PRE {module.__class__.__name__}')
-
-    def post_bwd_exec(self, module: torch.nn.Module, input):
-        # bwd Op will generate grad. comm_volume is grad + data volume on cuda.
-        assert isinstance(module, torch.nn.Module)
-        if module.training:
-            cuda_volume = self.mem_monitor.finish()
-            comm_volume = self._move_module_to_dev(module, 'cpu')
-            self._non_model_data_list.append(cuda_volume - comm_volume)
-            # print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
-
-    def pre_iter(self):
-        pass
-
-    def post_iter(self):
-        self.mem_monitor.finish()
-        # print(f'post_iter')
-
-    def print_non_model_data(self):
-        print(self._non_model_data_list)
-
-    def save_results(self, filename):
-        self.mem_monitor.save(filename)
-
-    def show_mem_stats(self):
-        start_timestamp = min(self.mem_monitor.time_stamps)
-        self.mem_monitor.time_stamps = [elem - start_timestamp for elem in self.mem_monitor.time_stamps]
-        min_mem_used = min(self.mem_monitor.mem_stats)
-        self.mem_monitor.mem_stats = [elem - min_mem_used for elem in self.mem_monitor.mem_stats]
-        print(self.mem_monitor.time_stamps)
-        print(self.mem_monitor.mem_stats)
diff --git a/colossalai/gemini/ophooks/param_trace_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
similarity index 93%
rename from colossalai/gemini/ophooks/param_trace_hook.py
rename to colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
index 678927d78..5f155f085 100644
--- a/colossalai/gemini/ophooks/param_trace_hook.py
+++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -6,9 +6,9 @@ from typing import List
 import torch
 
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.tensor.param_op_hook import ParamOpHook
-from colossalai.gemini.tensor_utils import free_storage, alloc_storage
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
+from colossalai.gemini.tensor_utils import alloc_storage, free_storage
+from colossalai.tensor.param_op_hook import ParamOpHook
 
 
 class TrainingPhase(Enum):
@@ -16,7 +16,8 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
-class GradHook():
+class GradMemTracerHook():
+
     def __init__(self, module: torch.nn.Module):
         self.module = module
         self.grad_hook_list = []
@@ -38,7 +39,7 @@ class GradHook():
             hook.remove()
 
 
-class ParamTracerHook(ParamOpHook):
+class ParamMemTracerHook(ParamOpHook):
 
     def __init__(self) -> None:
         super().__init__()
@@ -57,7 +58,9 @@ class ParamTracerHook(ParamOpHook):
         if cur_dev == "cpu":
             if p.grad is not None and p.grad.device.type == "cpu":
                 raise NotImplementedError("Only run in forward propagation")
-            p.data = torch.empty(p.data.shape, device="cuda", dtype=p.data.dtype,
+            p.data = torch.empty(p.data.shape,
+                                 device="cuda",
+                                 dtype=p.data.dtype,
                                  requires_grad=p.data.requires_grad)
         elif cur_dev == "cuda":
             alloc_storage(p.data)
diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py
index 47f6e432b..2806b8cb0 100644
--- a/tests/test_gemini/test_runtime_mem_tracer.py
+++ b/tests/test_gemini/test_runtime_mem_tracer.py
@@ -29,7 +29,7 @@ def test_runtime_mem_tracer():
     model_builder, train_dataloader, _, _, criterion = get_components_func()
 
     with ColoInitContext(device=torch.device('cpu')):
-        model = model_builder(checkpoint=True)
+        model = model_builder(checkpoint=False)
     model_bk = deepcopy(model)
     runtime_mem_tracer = RuntimeMemTracer(model)
 
@@ -47,7 +47,7 @@ def test_runtime_mem_tracer():
     cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
     print("cuda_non_model_data_list", len(cuda_non_model_data_list))
-    # print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
+    print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
 
     del model
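
Below is a minimal usage sketch of the renamed tracer, modeled on the test touched above. It is a hedged illustration, not the library's confirmed API: it assumes RuntimeMemTracer can be called like the wrapped module and exposes a backward(loss) entry point, and it swaps the test's registry-built model and ColoInitContext setup for a plain nn.Sequential.

# Hedged sketch: the forward call and `.backward(loss)` on RuntimeMemTracer are
# assumptions based on its wrapper role; the real test builds its model inside
# ColoInitContext and pulls it from a component registry instead.
import numpy as np
import torch
import torch.nn as nn

from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer

# Tiny stand-in model for illustration only.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

# Wrapping the module installs ParamMemTracerHook and GradMemTracerHook.
runtime_mem_tracer = RuntimeMemTracer(model)

data = torch.randn(4, 32)
label = torch.randint(0, 2, (4,))

output = runtime_mem_tracer(data)                    # assumed: traced forward
loss = nn.functional.cross_entropy(output, label)
runtime_mem_tracer.backward(loss)                    # assumed: traced backward

# Per-operator non-model CUDA memory recorded by the hooks, converted to MB.
cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)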