Browse Source

[zero] dump memory stats for sharded model (#548)

pull/552/head
Jiarui Fang 3 years ago committed by GitHub
parent
commit
107b99ddb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      colossalai/utils/memory_tracer/memstats_collector.py
  2. 30
      colossalai/zero/sharded_model/sharded_model_v2.py
  3. 11
      colossalai/zero/sharded_optim/sharded_optim_v2.py

8
colossalai/utils/memory_tracer/memstats_collector.py

@ -45,10 +45,18 @@ class MemStatsCollector:
def overall_cuda(self):
return self._overall_cuda
@property
def model_data_cuda_GB(self):
return [elem / 1e9 for elem in self._model_data_cuda]
@property
def model_data_cuda(self):
return self._model_data_cuda
@property
def non_model_data_cuda_GB(self):
return [elem / 1e9 for elem in self.non_model_data_cuda]
@property
def non_model_data_cuda(self):
"""Non model data stats

30
colossalai/zero/sharded_model/sharded_model_v2.py

@ -11,6 +11,9 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
@ -131,6 +134,29 @@ class ShardedModelV2(nn.Module):
def cpu_offload(self):
return self._cpu_offload
def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
"""
dummy memory tracer collected infomation to a file.
try:
# forward: model(inputs)
# backward: optimizer.backward()
except Exception as e:
model.dump_memory_stats()
exit(0)
"""
if self._use_memory_tracer:
self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
f.write('model data\n')
f.write(str(self._memstats_collector.model_data_cuda_GB))
f.write('\n')
f.write('non model data\n')
f.write(str(self._memstats_collector.non_model_data_cuda_GB))
f.write('\n')
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
if self._iter_cnter == 0 and self._memstats_collector:
# the opeartion will affect the flag in ZeroHook
@ -147,6 +173,7 @@ class ShardedModelV2(nn.Module):
def backward_by_grad(self, tensor, grad):
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
self._post_backward_operations()
for ophook in self._ophook_list:
ophook.post_iter()
@ -154,9 +181,6 @@ class ShardedModelV2(nn.Module):
def _update_memstats(self):
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
self.logger.debug(f'model data cuda, {self._memstats_collector.model_data_cuda}')
self.logger.debug(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
if self._memstats_collector:
self._memstats_collector.reset_sampling_cnter()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.

11
colossalai/zero/sharded_optim/sharded_optim_v2.py

@ -24,7 +24,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
class OptimState(Enum):
SCALED = 1
UNSCALED = 2
@ -139,6 +138,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
self._use_memory_tracer = self.model.use_memory_tracer
if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
def get_memory_usage(self) -> Tuple[int, int]:
"""
Get the memory usage of the optimizer. Including master_params (param fp32),
@ -190,12 +193,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Now p.data is sharded
# So optimizer states are sharded naturally
self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
self._logger.debug(
f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
ranks=[0])
ret = self.optim.step(*args, **kwargs)
self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
self._logger.debug(
f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
ranks=[0])
# Copy master param data (fp32) to payload of col_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transfering

Loading…
Cancel
Save