From 107b99ddb17621678607959bb8072ef43efc4b20 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 30 Mar 2022 09:38:44 +0800
Subject: [PATCH] [zero] dump memory stats for sharded model (#548)

---
 .../utils/memory_tracer/memstats_collector.py |  8 +++++
 .../zero/sharded_model/sharded_model_v2.py    | 30 +++++++++++++++++--
 .../zero/sharded_optim/sharded_optim_v2.py    | 15 ++++++----
 3 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
index f6b613e24..bea97a5bf 100644
--- a/colossalai/utils/memory_tracer/memstats_collector.py
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -45,10 +45,18 @@ class MemStatsCollector:
     def overall_cuda(self):
         return self._overall_cuda
 
+    @property
+    def model_data_cuda_GB(self):
+        return [elem / 1e9 for elem in self._model_data_cuda]
+
     @property
     def model_data_cuda(self):
         return self._model_data_cuda
 
+    @property
+    def non_model_data_cuda_GB(self):
+        return [elem / 1e9 for elem in self.non_model_data_cuda]
+
     @property
     def non_model_data_cuda(self):
         """Non model data stats
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index cfe98024b..68e69f301 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -11,6 +11,9 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
@@ -131,6 +134,29 @@ class ShardedModelV2(nn.Module):
     def cpu_offload(self):
         return self._cpu_offload
 
+    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
+        """
+        Dump the memory tracer's collected information to a file.
+        try:
+            # forward: model(inputs)
+            # backward: optimizer.backward()
+        except Exception as e:
+            model.dump_memory_stats()
+            exit(0)
+        """
+        if self._use_memory_tracer:
+            self.logger.error(f'dump memory tracer collected information to {filename}', ranks=[0])
+            if gpc.get_global_rank() == 0:
+                with open(filename, 'w+') as f:
+                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
+                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
+                    f.write('model data\n')
+                    f.write(str(self._memstats_collector.model_data_cuda_GB))
+                    f.write('\n')
+                    f.write('non model data\n')
+                    f.write(str(self._memstats_collector.non_model_data_cuda_GB))
+                    f.write('\n')
+
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         if self._iter_cnter == 0 and self._memstats_collector:
             # the opeartion will affect the flag in ZeroHook
@@ -147,6 +173,7 @@ class ShardedModelV2(nn.Module):
 
     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+        self._post_backward_operations()
         for ophook in self._ophook_list:
             ophook.post_iter()
 
@@ -154,9 +181,6 @@ class ShardedModelV2(nn.Module):
     def _update_memstats(self):
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
-            self.logger.debug(f'model data cuda, {self._memstats_collector.model_data_cuda}')
-            self.logger.debug(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
-
         if self._memstats_collector:
             self._memstats_collector.reset_sampling_cnter()
             # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 9db62b717..90d908044 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -24,7 +24,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
-
 class OptimState(Enum):
     SCALED = 1
     UNSCALED = 2
@@ -139,6 +138,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
 
+        self._use_memory_tracer = self.model.use_memory_tracer
+        if self._use_memory_tracer:
+            GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
+
     def get_memory_usage(self) -> Tuple[int, int]:
         """ Get the memory usage of the optimizer. Including master_params (param fp32),
@@ -190,13 +193,15 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Now p.data is sharded
         # So optimizer states are sharded naturally
 
-        self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
-                           ranks=[0])
+        self._logger.debug(
+            f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CPU Memory!",
+            ranks=[0])
 
         ret = self.optim.step(*args, **kwargs)
 
-        self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
-                           ranks=[0])
+        self._logger.debug(
+            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CPU Memory!",
+            ranks=[0])
         # Copy master param data (fp32) to payload of col_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
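
Usage sketch: a minimal example of how the new ShardedModelV2.dump_memory_stats hook might be called from a training step, mirroring the try/except pattern in the docstring above. It assumes the ShardedModelV2 was built with use_memory_tracer=True and is driven by a ShardedOptimizerV2; `model`, `optimizer`, `criterion`, `data` and `label` are placeholder names, not APIs introduced by this patch.

    # Minimal sketch, assuming `model` is a ShardedModelV2 created with
    # use_memory_tracer=True and `optimizer` is a ShardedOptimizerV2 wrapping it.
    # `criterion`, `data` and `label` are placeholders for the user's own objects.
    try:
        output = model(data)
        loss = criterion(output, label)
        optimizer.backward(loss)
        optimizer.step()
    except RuntimeError:
        # On rank 0, writes reserved / max-allocated CUDA memory plus the sampled
        # model-data and non-model-data curves (in GB) to dump_mem_stats.log.
        model.dump_memory_stats(filename='dump_mem_stats.log')
        exit(0)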