diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index f470671fc..5e46d215a 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
                      is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param)
+                     sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory_utils.memory_monitor import report_memory_usage
@@ -19,5 +19,5 @@ __all__ = [
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
     'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists'
+    'ensure_path_exists', 'disposable'
 ]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 2eb96cdfd..988dfc90b 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -4,8 +4,8 @@ import os
 import random
 import socket
 from pathlib import Path
-from typing import List, Union
-
+from typing import Callable, List, Union
+import functools
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
 
 
 class model_branch_context(object):
+
     def __enter__(self):
         self.env_status = env.save()
 
@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
             colossal_C.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
-            False # no per-parameter norm
+            False    # no per-parameter norm
         )
     return norm
 
@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
         yield
     finally:
         gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
+def disposable(func: Callable) -> Callable:
+    executed = False
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal executed
+        if not executed:
+            executed = True
+            return func(*args, **kwargs)
+
+    return wrapper
diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
index 72c41b470..1c93998cc 100644
--- a/colossalai/utils/memory_tracer/memstats_collector.py
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -1,36 +1,11 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List
 
 
-class SamplingCounter:
-
-    def __init__(self) -> None:
-        self._samplint_cnt = 0
-        self._max_sampling_cnt = None
-
-    def advance(self):
-        self._samplint_cnt += 1
-
-    def next(self):
-        assert self._max_sampling_cnt is not None
-        return (self._samplint_cnt + 1) % self._max_sampling_cnt
-
-    def current(self):
-        return self._samplint_cnt
-
-    def max(self):
-        return self._max_sampling_cnt
-
-    def reset(self):
-        self._max_sampling_cnt = self._samplint_cnt
-        self._samplint_cnt = 0
-
-
 class MemStatsCollector:
     """
     A Memory statistic collector.
@@ -44,7 +19,6 @@ class MemStatsCollector:
     """
 
     def __init__(self) -> None:
-        self._sampling_cnter = SamplingCounter()
         self._mem_monitor = AsyncMemoryMonitor()
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -57,6 +31,7 @@ class MemStatsCollector:
         self._sampling_time = []
 
         self._start_flag = False
+        self._period_idx = 0
 
     def overall_mem_stats(self, device_type: str):
         if device_type == 'cuda':
@@ -106,15 +81,22 @@ class MemStatsCollector:
         else:
             raise TypeError
 
-    def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of the current sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
+    def max_non_model_data(self, device_type: str) -> int:
+        """Get the maximum non-model data memory usage of the current sampling period.
 
-    def next_non_model_data(self, device_type: str):
-        """get the non model data of the next sampling moment
+        Args:
+            device_type (str): device type, either 'cpu' or 'cuda'.
+
+        Returns:
+            int: the maximum non-model data memory usage of the current sampling period
         """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
+        assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
+        next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
+        current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
+        next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
+        self._period_idx = next_period_idx
+        return max(current_non_model_data, next_non_model_data)
 
     @property
     def sampling_time(self):
@@ -126,6 +108,7 @@ class MemStatsCollector:
 
     def finish_collection(self):
         self._start_flag = False
+        self._mem_monitor.finish()
 
     def sample_memstats(self) -> None:
         """
@@ -134,8 +117,6 @@ class MemStatsCollector:
         Sampling memory statistics.
         Advance the sampling cnter.
""" if self._start_flag: - sampling_cnt = self._sampling_cnter.current() - assert sampling_cnt == len(self._overall_cuda_list) self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage) self._overall_cuda_list.append(self._mem_monitor.finish()) self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1]) @@ -146,13 +127,6 @@ class MemStatsCollector: self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) self._sampling_time.append(time.time()) self._mem_monitor.start() - # TODO(ver217): refactor sampler - # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}') - self._sampling_cnter.advance() - - def reset_sampling_cnter(self) -> None: - self._sampling_cnter.reset() - self._mem_monitor.finish() def clear(self) -> None: self._model_data_cuda_list = [] @@ -162,5 +136,4 @@ class MemStatsCollector: self._overall_cpu_list = [] self._start_flag = False - self._sampling_cnter.reset() - self._mem_monitor.finish() + self._period_idx = 0 diff --git a/colossalai/utils/memory_tracer/test_async_memtracer.py b/colossalai/utils/memory_tracer/test_async_memtracer.py deleted file mode 100644 index 06c4052bd..000000000 --- a/colossalai/utils/memory_tracer/test_async_memtracer.py +++ /dev/null @@ -1,16 +0,0 @@ -from async_memtracer import AsyncMemoryMonitor -import torch - -if __name__ == '__main__': - async_mem_monitor = AsyncMemoryMonitor() - input = torch.randn(2, 20).cuda() - OP1 = torch.nn.Linear(20, 30).cuda() - OP2 = torch.nn.Linear(30, 40).cuda() - - async_mem_monitor.start() - output = OP1(input) - async_mem_monitor.finish() - async_mem_monitor.start() - output = OP2(output) - async_mem_monitor.finish() - async_mem_monitor.save('log.pkl') diff --git a/colossalai/utils/memory_tracer/test_memstats_collector.py b/colossalai/utils/memory_tracer/test_memstats_collector.py deleted file mode 100644 index 660938b92..000000000 --- a/colossalai/utils/memory_tracer/test_memstats_collector.py +++ /dev/null @@ -1,37 +0,0 @@ -from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector -import torch - - -def test_mem_collector(): - collector = MemStatsCollector() - - collector.start_collection() - - a = torch.randn(10).cuda() - - # sampling at time 0 - collector.sample_memstats() - - m_a = torch.randn(10).cuda() - b = torch.randn(10).cuda() - - # sampling at time 1 - collector.sample_memstats() - - a = b - - # sampling at time 2 - collector.sample_memstats() - - collector.finish_collection() - collector.reset_sampling_cnter() - - # do nothing after collection, just advance sampling cnter - collector.sample_memstats() - collector.sample_memstats() - - print(collector.overall_mem_stats('cuda')) - - -if __name__ == '__main__': - test_mem_collector() diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/shard_utils/stateful_tensor_mgr.py index 3a14f5139..63a89fcdc 100644 --- a/colossalai/zero/shard_utils/stateful_tensor_mgr.py +++ b/colossalai/zero/shard_utils/stateful_tensor_mgr.py @@ -71,8 +71,7 @@ class StatefulTensorMgr(object): max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. 
-            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
-                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
 
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 199ec882e..bd3752482 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
                 for param in submodule.parameters(recurse=False):
                     if hasattr(param, 'colo_attr'):
                         self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
+            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
             self._stateful_tensor_mgr = None
-        self._iter_cnter = 0
 
         # Register hooks
         self._ophook_list = [
@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
                 f.write('\n')
 
     def _pre_forward_operations(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            # the operation will affect the memory tracer behavior in ZeroHook
-            self._memstats_collector.start_collection()
+        # the operation will affect the memory tracer behavior in ZeroHook
+        if self._memstats_collector:
+            self._start_collect_memstats()
 
         for p in self.module.parameters():
             if hasattr(p, 'colo_attr'):
@@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
                 ophook.post_iter()
 
     def _update_memstats(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            self._memstats_collector.finish_collection()
         if self._memstats_collector:
-            self._memstats_collector.reset_sampling_cnter()
+            self._finish_collect_memstats()
             # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_cuda_memory_capacity() - max(
                 self._memstats_collector.overall_mem_stats('cuda'))
-        self._iter_cnter += 1
 
     @torch.no_grad()
     def _post_backward_operations(self) -> None:
diff --git a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py
index b77d02e94..857dad5ea 100644
--- a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py
+++ b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py
@@ -55,7 +55,6 @@ def run_stm():
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
     mem_collector.sample_memstats()
     mem_collector.finish_collection()
-    mem_collector.reset_sampling_cnter()
     stateful_tensor_mgr.reset()
 
     # warmup done
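
Note for reviewers: a minimal standalone sketch of the `disposable` helper added to `colossalai/utils/common.py` above. The decorator body is copied verbatim from the diff; `start_collection` here is a hypothetical stand-in for the bound methods wrapped in `ShardedModelV2`:

```python
import functools
from typing import Callable


def disposable(func: Callable) -> Callable:
    # Only the first call executes `func`; every later call is a silent no-op.
    executed = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)

    return wrapper


@disposable
def start_collection():    # hypothetical stand-in for MemStatsCollector.start_collection
    print('collection started')


start_collection()    # prints 'collection started'
start_collection()    # no-op: the wrapped function already ran once
```

Because the wrapper swallows every call after the first, `ShardedModelV2._pre_forward_operations` can call `self._start_collect_memstats()` unconditionally on every iteration, which is what allows the old `_iter_cnter` bookkeeping to be deleted.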
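
Likewise, a small sketch (with made-up byte counts) of the period-cycling logic that replaces `SamplingCounter`: `max_non_model_data` returns the peak non-model data usage over the current and the next sampling period, then advances `_period_idx`, so consecutive calls from `StatefulTensorMgr` walk through the periods recorded during the warmup iteration:

```python
# Hypothetical per-period non-model data usage (bytes) recorded during warmup.
non_model_data_list = [100, 250, 180]
period_idx = 0


def max_non_model_data() -> int:
    # Mirrors MemStatsCollector.max_non_model_data: take the max of the current
    # and the next period, then advance to the next period (wrapping around).
    global period_idx
    next_period_idx = (period_idx + 1) % len(non_model_data_list)
    result = max(non_model_data_list[period_idx], non_model_data_list[next_period_idx])
    period_idx = next_period_idx
    return result


print(max_non_model_data())    # max(100, 250) -> 250
print(max_non_model_data())    # max(250, 180) -> 250
print(max_non_model_data())    # max(180, 100) -> 180, wrapped back to period 0
```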
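
With the ad-hoc scripts under `colossalai/utils/memory_tracer/` deleted, the collection lifecycle of `MemStatsCollector` now looks like this. A sketch assuming a CUDA device, mirroring the removed `test_memstats_collector.py` but using the new API from the diff:

```python
import torch
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()
collector.start_collection()

a = torch.randn(10).cuda()
collector.sample_memstats()    # record sampling period 0, restart the monitor

b = torch.randn(10).cuda()
collector.sample_memstats()    # record sampling period 1

collector.finish_collection()    # now also stops the async memory monitor

print(collector.overall_mem_stats('cuda'))
print(collector.max_non_model_data('cuda'))    # peak over period 0 and period 1
```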