[hotfix] fix a bug in model data stats tracing (#655)

pull/666/head
Jiarui Fang 2022-04-03 21:48:06 +08:00 committed by GitHub
parent ade05a5d83
commit 0aab52301e
4 changed files with 15 additions and 12 deletions

View File

@@ -1,7 +1,7 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
+from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List
@@ -37,6 +37,7 @@ class MemStatsCollector:
     def __init__(self) -> None:
         self._sampling_cnter = SamplingCounter()
+        self._mem_monitor = AsyncMemoryMonitor()
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -101,6 +102,7 @@ class MemStatsCollector:
     def start_collection(self):
         self._start_flag = True
+        self._mem_monitor.start()
 
     def finish_collection(self):
         self._start_flag = False
@@ -115,17 +117,20 @@
         sampling_cnt = self._sampling_cnter.sampling_cnt
         assert sampling_cnt == len(self._overall_cuda_list)
         self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
-        self._overall_cuda_list.append(colo_device_memory_used(get_current_device()))
+        self._overall_cuda_list.append(self._mem_monitor.finish())
         self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
+        # FIXME() cpu sys used should also return from self._mem_monitor()
         self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
         self._sampling_time.append(time.time())
+        self._mem_monitor.start()
         self._sampling_cnter.advance()
 
     def reset_sampling_cnter(self) -> None:
         self._sampling_cnter.reset()
+        self._mem_monitor.finish()
 
     def clear(self) -> None:
         self._model_data_cuda_list = []
@@ -136,3 +141,4 @@ class MemStatsCollector:
         self._start_flag = False
         self._sampling_cnter.reset()
+        self._mem_monitor.finish()
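
The change above replaces point-in-time reads of device memory with an interval measurement: each sample closes the previous interval with self._mem_monitor.finish() and opens the next one with self._mem_monitor.start(). Below is a minimal sketch of that start()/finish() pattern; PeakCudaMonitor is a simplified stand-in written for illustration, not ColossalAI's AsyncMemoryMonitor.

    import threading
    import time

    import torch


    class PeakCudaMonitor:
        """Illustrative stand-in: poll allocator stats in a background thread
        and report the peak observed between start() and finish()."""

        def __init__(self, poll_interval: float = 0.001) -> None:
            self._poll_interval = poll_interval
            self._peak = 0
            self._running = False
            self._thread = None

        def _poll(self) -> None:
            while self._running:
                self._peak = max(self._peak, torch.cuda.memory_allocated())
                time.sleep(self._poll_interval)

        def start(self) -> None:
            self._peak = torch.cuda.memory_allocated()
            self._running = True
            self._thread = threading.Thread(target=self._poll, daemon=True)
            self._thread.start()

        def finish(self) -> int:
            self._running = False
            self._thread.join()
            return self._peak


    if torch.cuda.is_available():
        monitor = PeakCudaMonitor()
        monitor.start()                              # open the sampling interval
        x = torch.randn(1024, 1024, device='cuda')   # workload inside the interval
        y = x @ x
        print(f"peak CUDA bytes in interval: {monitor.finish()}")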

View File

@@ -33,7 +33,7 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     def _get_tensor_mem_use(t: Optional[torch.Tensor]):
         if t is None:
-            return
+            return 0, 0
         assert isinstance(t, torch.Tensor)
         _cpu_mem_usage, _cuda_mem_usage = 0, 0
         if t.device.type == 'cpu':
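
This one-line change is the stats-tracing bug itself: with a bare return, _get_tensor_mem_use hands back None for a missing tensor payload, and any caller that unpacks or sums the (cpu, cuda) pair breaks or miscounts. A self-contained reproduction of the failure mode, using a hypothetical tensor_mem_use helper rather than the ColossalAI function:

    from typing import Optional, Tuple

    import torch


    def tensor_mem_use(t: Optional[torch.Tensor]) -> Tuple[int, int]:
        """Return (cpu_bytes, cuda_bytes) used by t; (0, 0) when t is None."""
        if t is None:
            return 0, 0          # the fix: keep the return type consistent
        n_bytes = t.numel() * t.element_size()
        if t.device.type == 'cpu':
            return n_bytes, 0
        return 0, n_bytes


    # With a bare `return` (i.e. None) instead of `return 0, 0`, the unpacking
    # below raises "TypeError: cannot unpack non-iterable NoneType".
    tensors = [torch.zeros(4), None, torch.ones(8)]
    cpu_total = cuda_total = 0
    for t in tensors:
        cpu_use, cuda_use = tensor_mem_use(t)
        cpu_total += cpu_use
        cuda_total += cuda_use
    print(cpu_total, cuda_total)   # 48 0 on a CPU-only run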

View File

@@ -139,10 +139,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
-        self._use_memory_tracer = self.model.use_memory_tracer
-        if self._use_memory_tracer:
-            GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
 
     def get_memory_usage(self) -> Tuple[int, int]:
         """ Get the memory usage of the optimizer. Including master_params (param fp32),
         momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
@@ -186,7 +182,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self._zero_grad(recover_data=True)
             return
 
-        self._prepare_data()
+        self._point_param_fp16_to_master_param()
         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
@@ -197,7 +193,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        self._write_back_data()
+        self._copy_master_param_to_param_fp16()
         return ret
 
     def backward(self, loss: Tensor) -> None:
def backward(self, loss: Tensor) -> None:
@@ -319,7 +315,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Set p.data to empty tensor, in case of memory leaking
                 p.colo_attr.remove_torch_payload()
 
-    def _prepare_data(self):
+    def _point_param_fp16_to_master_param(self):
         # assign master param pointers to p.data.
         # We will not trigger data copy here.
         for group in self.optim.param_groups:
@@ -329,7 +325,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Now p.data is sharded
                 # So optimizer states are sharded naturally
 
-    def _write_back_data(self):
+    def _copy_master_param_to_param_fp16(self):
         # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
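
The renames in this file make the step() flow explicit: before the wrapped optimizer runs, each fp16 parameter's .data is pointed at its fp32 master copy (no data movement), and after the step the updated master values are cast back into the fp16 payload. The sketch below shows that idea on a single process with plain SGD; mixed_precision_step and its argument names are illustration-only and not ShardedOptimizerV2's API.

    import torch


    def mixed_precision_step(params, master_params, optim):
        # Keep handles to the original fp16 storages so they can be restored.
        fp16_payloads = [p.data for p in params]

        # "point param fp16 to master param": repoint .data, no copy; the
        # optimizer then does its math on fp32 tensors.
        for p, master in zip(params, master_params):
            p.data = master
            p.grad = p.grad.float()

        optim.step()

        # "copy master param to param fp16": cast the updated fp32 values back
        # into the fp16 payload and restore the original storage.
        for p, payload, master in zip(params, fp16_payloads, master_params):
            payload.copy_(master)
            p.data = payload
            p.grad = None


    # usage on one hypothetical fp16 parameter
    p = torch.nn.Parameter(torch.ones(4, dtype=torch.float16))
    p.grad = torch.full((4,), 0.5, dtype=torch.float16)
    master = p.data.detach().clone().float()
    optim = torch.optim.SGD([p], lr=0.1)
    mixed_precision_step([p], [master], optim)
    print(p.data)   # ~0.95, updated through the fp32 master copy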

View File

@@ -91,6 +91,7 @@ def _run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("Under development")
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_zero_init(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
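
For reference, the rerun_on_exception marker kept above reruns a test whose failure message matches a pattern (here the flaky "Address already in use" raised during distributed init). A minimal sketch of that retry pattern follows; rerun_on_exception_sketch, flaky_bind, and attempts are illustrative names, not ColossalAI's implementation.

    import functools
    import re


    def rerun_on_exception_sketch(exception_type=Exception, pattern='.*', max_try=3):
        """Rerun the wrapped callable while it raises exception_type with a
        message matching pattern; give up after max_try attempts."""

        def decorator(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                last_exc = None
                for _ in range(max_try):
                    try:
                        return func(*args, **kwargs)
                    except exception_type as exc:
                        if not re.search(pattern, str(exc)):
                            raise              # unrelated failure: surface it now
                        last_exc = exc         # flaky failure: retry
                raise last_exc

            return wrapper

        return decorator


    attempts = {'n': 0}


    @rerun_on_exception_sketch(exception_type=OSError, pattern='.*Address already in use.*')
    def flaky_bind():
        attempts['n'] += 1
        if attempts['n'] < 2:
            raise OSError('[Errno 98] Address already in use')
        return 'bound'


    print(flaky_bind())    # fails once, is rerun, then prints 'bound'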