diff --git a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
similarity index 83%
rename from colossalai/gemini/memory_tracer/param_tracer_wrapper.py
rename to colossalai/gemini/memory_tracer/runtime_mem_tracer.py
index f69df73e3..829e0d4d4 100644
--- a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
+++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,13 +1,14 @@
 import torch.nn
 
-from colossalai.tensor.param_op_hook import ParamOpHookManager
-from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
+from colossalai.gemini.ophooks.param_trace_hook import GradHook, ParamTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 
-__all__ = ['ParamTracerWrapper']
+__all__ = ['RuntimeMemTracer']
 
-class ParamTracerWrapper():
+
+class RuntimeMemTracer():
 
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
@@ -25,12 +26,18 @@ class ParamTracerWrapper():
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
-    def _save_param_data_on_cpu(self):
+    def _backup_params(self):
+        """
+        The function is called before forward. Backup model params on cpu.
+        """
         for p in self.module.parameters():
             self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
             self.cpu_param_data_dict[p].copy_(p.data)
 
-    def _restore_param_data(self):
+    def _restore_params(self):
+        """
+        This function is called after backward. Restore model params.
+        """
         for p in self.module.parameters():
             p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
             p.data.copy_(self.cpu_param_data_dict[p])
@@ -38,7 +45,7 @@ class ParamTracerWrapper():
 
     def _pre_forward(self):
         self._clear_cuda_mem_info()
-        self._save_param_data_on_cpu()
+        self._backup_params()
         self.grad_hook.register_grad_hook()
         self.param_op_hook.mem_monitor.start()
 
@@ -60,7 +67,7 @@ class ParamTracerWrapper():
         last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
         GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
         self.grad_hook.remove_grad_hook()
-        self._restore_param_data()
+        self._restore_params()
 
     def _clear_cuda_mem_info(self):
         GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
@@ -72,4 +79,4 @@ class ParamTracerWrapper():
         for buffer in self.module.buffers():
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
-                buffer.data = buffer.data.to(self.dtype)
\ No newline at end of file
+                buffer.data = buffer.data.to(self.dtype)
diff --git a/tests/test_gemini/test_param_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py
similarity index 74%
rename from tests/test_gemini/test_param_tracer.py
rename to tests/test_gemini/test_runtime_mem_tracer.py
index 7e4c6dff5..0b112f66f 100644
--- a/tests/test_gemini/test_param_tracer.py
+++ b/tests/test_gemini/test_runtime_mem_tracer.py
@@ -1,11 +1,15 @@
+from copy import deepcopy
+
 import numpy as np
 import torch
 
-from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
+from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 
+
 def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torch.half):
     with torch.cuda.amp.autocast(enabled=enable_autocast):
         if criterion:
@@ -16,9 +20,9 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
         loss = loss.to(dtype)
     model.backward(loss)
 
+
 def run_param_wrapper_testing():
     test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
-
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -26,7 +30,8 @@ def run_param_wrapper_testing():
         with ColoInitContext(device=torch.device('cpu')):
            model = model_builder(checkpoint=False)
 
-        model = ParamTracerWrapper(model)
+        model_bk = deepcopy(model)
+        runtime_mem_tracer = RuntimeMemTracer(model)
 
         for i, (data, label) in enumerate(train_dataloader):
             if i > 1:
@@ -34,15 +39,17 @@ def run_param_wrapper_testing():
             data = data.cuda()
             label = label.cuda()
 
-            run_fwd_bwd(model, data, label, criterion, False)
+            run_fwd_bwd(runtime_mem_tracer, data, label, criterion, False)
 
-        cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024 ** 2
+        for p1, p2 in zip(model_bk.parameters(), model.parameters()):
+            torch.allclose(p1.to(torch.half), p2)
+
+        cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
         print("cuda_non_model_data_list", len(cuda_non_model_data_list))
         # print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
 
         del model
 
 
-
 if __name__ == '__main__':
-    run_param_wrapper_testing()
\ No newline at end of file
+    run_param_wrapper_testing()
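For reference, a minimal usage sketch of the renamed tracer, assuming only the interface visible in this diff: the wrapper is constructed from an nn.Module, `__call__` dispatches to the traced forward pass, `backward()` finishes tracing (removing the grad hooks and restoring the CPU parameter backup), and per-iteration non-model CUDA usage accumulates in `GLOBAL_CUDA_MEM_INFO.non_model_data_list`. The helper name `collect_non_model_data_mb` is illustrative and not part of the API.

import torch

from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer


def collect_non_model_data_mb(model: torch.nn.Module, data, label, criterion):
    # Hypothetical helper: run one traced fwd/bwd step and return the recorded
    # non-model CUDA memory volumes in MB.
    tracer = RuntimeMemTracer(model, dtype=torch.half)    # wraps the module; floating-point buffers are cast to dtype
    output = tracer(data.cuda())                          # __call__ -> forward; params are backed up on CPU beforehand
    loss = criterion(output, label.cuda()).to(torch.half)
    tracer.backward(loss)                                 # removes grad hooks and restores the backed-up params
    return [v / 1024**2 for v in GLOBAL_CUDA_MEM_INFO.non_model_data_list]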