diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py
index 4fbc1a477..3ce2f4d55 100644
--- a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py
@@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
         super().__init__()
         self._chunk_manager = chunk_manager
 
+    # override
     def sample_model_data(self) -> None:
         """Sampling model data statistics.
         """
         if self._start_flag:
             cuda_mem = self._chunk_manager.total_mem['cuda']
             cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.append_model_data('cuda', cuda_mem)
+            self._memstats.append_model_data('cpu', cpu_mem)
 
     @property
    def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
+        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py
new file mode 100644
index 000000000..2bb859683
--- /dev/null
+++ b/colossalai/gemini/memory_tracer/memory_stats.py
@@ -0,0 +1,94 @@
+from typing import Any, Dict, List
+
+
+class MemStats(object):
+
+    def __init__(self) -> None:
+        """
+        Store the memory statistics (model data, overall data and non model data) used by Gemini and ZeroOptimizer.
+        """
+        # p -> list of non model data volumes visited in order.
+        self.param_non_model_data_map: Dict[Any, List[int]] = {}
+
+        self._model_data_cuda_list = []
+        self._model_data_cpu_list = []
+
+        self._overall_cuda_list = []
+        self._overall_cpu_list = []
+
+        self._non_model_data_cuda_list = []
+        self._non_model_data_cpu_list = []
+
+    def append_overall_data(self, device_type: str, val: float):
+        if device_type == 'cuda':
+            self._overall_cuda_list.append(val)
+        elif device_type == 'cpu':
+            self._overall_cpu_list.append(val)
+        else:
+            raise TypeError
+
+    def append_model_data(self, device_type: str, val: float):
+        if device_type == 'cuda':
+            self._model_data_cuda_list.append(val)
+        elif device_type == 'cpu':
+            self._model_data_cpu_list.append(val)
+        else:
+            raise TypeError
+
+    def append_non_model_data(self, device_type: str):
+        if device_type == 'cuda':
+            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
+        elif device_type == 'cpu':
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+        else:
+            raise TypeError
+
+    def overall_mem_stats(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._overall_cuda_list
+        elif device_type == 'cpu':
+            return self._overall_cpu_list
+        else:
+            raise TypeError
+
+    def model_data_list(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._model_data_cuda_list
+        elif device_type == 'cpu':
+            return self._model_data_cpu_list
+        else:
+            raise TypeError
+
+    def non_model_data_list(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._non_model_data_cuda_list
+        elif device_type == 'cpu':
+            return self._non_model_data_cpu_list
+        else:
+            raise TypeError
+
+    def max_non_model_data(self, device_type: str) -> float:
+        if device_type == 'cuda':
+            return max(self._non_model_data_cuda_list)
+        elif device_type == 'cpu':
+            return max(self._non_model_data_cpu_list)
+        else:
+            raise TypeError
+
+    def max_overall_cuda(self, device_type: str) -> float:
+        if device_type == 'cuda':
+            return max(self._overall_cuda_list)
+        elif device_type == 'cpu':
+            return max(self._overall_cpu_list)
+        else:
+            raise TypeError
+
+    def clear(self):
+        self._model_data_cuda_list = []
+        self._overall_cuda_list = []
+
+        self._model_data_cpu_list = []
+        self._overall_cpu_list = []
+
+        self._non_model_data_cpu_list = []
+        self._non_model_data_cuda_list = []
diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py
index 5074f3f32..6f0d8b271 100644
--- a/colossalai/gemini/memory_tracer/memstats_collector.py
+++ b/colossalai/gemini/memory_tracer/memstats_collector.py
@@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.utils.memory import colo_device_memory_used
 
+from .memory_stats import MemStats
+
 
 class MemStatsCollector:
     """
@@ -22,43 +24,12 @@ class MemStatsCollector:
 
     def __init__(self) -> None:
         self._mem_monitor = SyncCudaMemoryMonitor()
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cuda_list = []
-        self._non_model_data_cpu_list = []
         self._sampling_time = []
 
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-    def overall_mem_stats(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._overall_cuda_list
-        elif device_type == 'cpu':
-            return self._overall_cpu_list
-        else:
-            raise TypeError
-
-    def model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._model_data_cpu_list
-        else:
-            raise TypeError
-
-    def non_model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._non_model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._non_model_data_cpu_list
-        else:
-            raise TypeError
+        self._memstats = MemStats()
 
     def next_period_non_model_data_usage(self, device_type: str) -> int:
         """Get max non model data memory usage of current sampling period
@@ -71,7 +42,7 @@ class MemStatsCollector:
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
-        next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
+        next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
 
@@ -95,37 +66,29 @@ class MemStatsCollector:
         if self._start_flag:
             cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
             cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.append_model_data('cuda', cuda_mem)
+            self._memstats.append_model_data('cpu', cpu_mem)
 
     def sample_overall_data(self) -> None:
         """Sampling non model data statistics.
""" if self._start_flag: # overall data recording is after model data recording - if len(self._model_data_cuda_list) == 0: + if len(self._memstats._model_data_cuda_list) == 0: return - self._overall_cuda_list.append(self._mem_monitor.finish()) - self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu'))) + self._memstats.append_overall_data('cuda', self._mem_monitor.finish()) + self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu'))) - assert len(self._model_data_cuda_list) == len(self._overall_cuda_list) + assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list) - self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1]) - self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) + self._memstats.append_non_model_data('cuda') + self._memstats.append_non_model_data('cpu') self._sampling_time.append(time.time()) self._mem_monitor.start() def clear(self) -> None: - self._model_data_cuda_list = [] - self._overall_cuda_list = [] - - self._model_data_cpu_list = [] - self._overall_cpu_list = [] - - self._non_model_data_cpu_list = [] - self._non_model_data_cuda_list = [] - + self._memstats.clear() self._start_flag = False self._step_idx = 0 self._step_total = 0 diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index bbc2b1d25..47487ef15 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module): tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, reuse_fp16_shard: bool = False, - user_static_memstats: bool = False, *args, **kwargs): assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' 
@@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        self.user_static_memstats = user_static_memstats
 
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
-            if self.user_static_memstats:
-                self._memstats_collector = StaticMemStatsCollector(self.module)
-            else:
-                self._memstats_collector = MemStatsCollector()
+            self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
             f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
             f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
             f.write('CUDA model data (GB)\n')
-            f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
+            f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
             f.write('\n')
             f.write('CUDA non model data (GB)\n')
-            f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
+            f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
             f.write('CPU non model data (GB)\n')
-            f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
+            f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
             f.write('\n')
 
     def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
         if self._memstats_collector:
-            if self.user_static_memstats:
-                self.init_mem_stats(*args)
             self._start_collect_memstats()
 
         for p in self.module.parameters():
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
-                self._memstats_collector.overall_mem_stats('cuda'))
+                self._memstats_collector._memstats.overall_mem_stats('cuda'))
 
     @torch.no_grad()
     def _post_backward_operations(self) -> None:
diff --git a/tests/test_zero/test_mem_collector.py b/tests/test_zero/test_mem_collector.py
index bea971935..eea0a04a0 100644
--- a/tests/test_zero/test_mem_collector.py
+++ b/tests/test_zero/test_mem_collector.py
@@ -1,74 +1,77 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.nn.functional as F
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
-from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.shard_utils import BucketTensorShardStrategy
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
-
-
-class MyTestModel(torch.nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.proj1 = nn.Linear(512, 512)
-        self.weight = nn.Parameter(torch.randn(1024, 512))
-        self.proj2 = nn.Linear(1024, 512)
-
-    def forward(self, x):
-        x = self.proj1(x)
-        x = F.linear(x, self.weight)
-        x = self.proj2(x)
-
-        return x
-
-
-def run_mem_collector_testing():
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
-    fraction = (50 * 1024**2) / cuda_capacity
-    # limit max memory to 50MB
-    colo_set_process_memory_fraction(fraction)
-    shard_strategy = BucketTensorShardStrategy()
-    with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
-        model = MyTestModel()
-
-    model = ShardedModelV2(module=model,
-                           shard_strategy=shard_strategy,
-                           reduce_scatter_bucket_size_mb=1,
-                           tensor_placement_policy='auto')
-
-    data = torch.randn(2, 512, device=get_current_device())
-
-    output = model(data)
-    loss = torch.mean(output)
-    model.backward(loss)
-
-    cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
-    assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
-
-    cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
-    assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
-    assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_mem_collector_testing()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_mem_collector(world_size=2):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_mem_collector()
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
+from colossalai.zero.sharded_model import ShardedModelV2
+
+
+class MyTestModel(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.proj1 = nn.Linear(512, 512)
+        self.weight = nn.Parameter(torch.randn(1024, 512))
+        self.proj2 = nn.Linear(1024, 512)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = F.linear(x, self.weight)
+        x = self.proj2(x)
+
+        return x
+
+
+def run_mem_collector_testing():
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    fraction = (50 * 1024**2) / cuda_capacity
+    # limit max memory to 50MB
+    colo_set_process_memory_fraction(fraction)
+    shard_strategy = BucketTensorShardStrategy()
+    with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
+        model = MyTestModel()
+
+    model = ShardedModelV2(module=model,
+                           shard_strategy=shard_strategy,
+                           reduce_scatter_bucket_size_mb=1,
+                           tensor_placement_policy='auto')
+
+    data = torch.randn(2, 512, device=get_current_device())
+
+    output = model(data)
+    loss = torch.mean(output)
+    model.backward(loss)
+
+    cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
+    assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
+
+    cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
+    print('cuda_non_model_data_list ', cuda_non_model_data_list)
+    assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
+    assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_mem_collector_testing()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_mem_collector(world_size=2):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mem_collector()
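
For reference, a minimal usage sketch of the new MemStats container (not part of the diff above): it uses only the methods added in memory_stats.py and shows the call order MemStatsCollector relies on, where model data is appended before overall data and non model data is derived from the two. The sample byte counts are made up.

# Illustrative sketch only; values are hypothetical, API follows memory_stats.py above.
from colossalai.gemini.memory_tracer.memory_stats import MemStats

stats = MemStats()
for model_mem, overall_mem in [(1311744, 2000000), (1836032, 2600000)]:
    stats.append_model_data('cuda', model_mem)      # model data sampled first
    stats.append_overall_data('cuda', overall_mem)  # then the overall snapshot
    stats.append_non_model_data('cuda')             # stores overall[-1] - model[-1]

print(stats.non_model_data_list('cuda'))  # [688256, 763968]
print(stats.max_overall_cuda('cuda'))     # 2600000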