From 85efb7ac2ee6b000aa403058e275d3da352d036e Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 7 Dec 2022 23:04:02 +0800 Subject: [PATCH] [Gemini] gemini use the runtime memory tracer (RMT) (#2099) --- colossalai/gemini/gemini_mgr.py | 8 +- colossalai/gemini/memory_tracer/__init__.py | 2 +- .../memory_tracer/chunk_memstats_collector.py | 9 +- .../memory_tracer/memstats_collector.py | 19 ++-- .../memory_tracer/runtime_mem_tracer.py | 3 + .../test_gemini/update/test_gemini_use_rmt.py | 92 +++++++++++++++++++ 6 files changed, 120 insertions(+), 13 deletions(-) create mode 100644 tests/test_gemini/update/test_gemini_use_rmt.py diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 317c4f15c..c3a813367 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -5,8 +5,9 @@ from typing import List, Optional, Tuple import torch from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini.memory_tracer import MemStats -from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector +from .memory_tracer import ChunkMemStatsCollector from .placement_policy import PlacementPolicyFactory @@ -26,13 +27,14 @@ class GeminiManager: chunk_manager (ChunkManager): A ``ChunkManager`` instance. """ - def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: + def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: assert placement_policy in PlacementPolicyFactory.get_polocy_names() self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) self._chunk_manager = chunk_manager - self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None + self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, + memstats) if policy_cls.need_mem_stats else None self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index b571e31b2..c7b7efad7 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -1,8 +1,8 @@ +from .memory_stats import MemStats # isort:skip from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip from .memstats_collector import MemStatsCollector # isort:skip from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip from .static_memstats_collector import StaticMemStatsCollector # isort:skip -from .memory_stats import MemStats __all__ = [ 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py index 3ce2f4d55..6c681d31f 100644 --- a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,4 +1,7 @@ +from typing import Optional + from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini.memory_tracer import MemStats from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity @@ -7,15 +10,15 @@ from .memstats_collector import MemStatsCollector class ChunkMemStatsCollector(MemStatsCollector): - def __init__(self, chunk_manager: ChunkManager) -> None: - super().__init__() + def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + super().__init__(memstats) self._chunk_manager = chunk_manager # override def sample_model_data(self) -> None: """Sampling model data statistics. """ - if self._start_flag: + if self._start_flag and not self.use_outside_memstats: cuda_mem = self._chunk_manager.total_mem['cuda'] cpu_mem = self._chunk_manager.total_mem['cpu'] self._memstats.append_model_data('cuda', cuda_mem) diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index 6f0d8b271..7d034dd8f 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -1,5 +1,5 @@ import time -from typing import List +from typing import List, Optional import torch @@ -22,14 +22,19 @@ class MemStatsCollector: It has a Sampling counter which is reset after DNN training iteration. """ - def __init__(self) -> None: + def __init__(self, memstats: Optional[MemStats] = None) -> None: self._mem_monitor = SyncCudaMemoryMonitor() self._sampling_time = [] self._start_flag = False self._step_idx = 0 self._step_total = 0 - self._memstats = MemStats() + if memstats is not None: + self.use_outside_memstats = True + self._memstats = memstats + else: + self.use_outside_memstats = False + self._memstats = MemStats() def next_period_non_model_data_usage(self, device_type: str) -> int: """Get max non model data memory usage of current sampling period @@ -63,7 +68,7 @@ class MemStatsCollector: def sample_model_data(self) -> None: """Sampling model data statistics. """ - if self._start_flag: + if self._start_flag and not self.use_outside_memstats: cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] self._memstats.append_model_data('cuda', cuda_mem) @@ -72,7 +77,7 @@ class MemStatsCollector: def sample_overall_data(self) -> None: """Sampling non model data statistics. """ - if self._start_flag: + if self._start_flag and not self.use_outside_memstats: # overall data recording is after model data recording if len(self._memstats._model_data_cuda_list) == 0: return @@ -84,9 +89,11 @@ class MemStatsCollector: self._memstats.append_non_model_data('cuda') self._memstats.append_non_model_data('cpu') - self._sampling_time.append(time.time()) self._mem_monitor.start() + if self._start_flag: + self._sampling_time.append(time.time()) + def clear(self) -> None: self._memstats.clear() self._start_flag = False diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py index 724afcfe3..1090cf92c 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -35,6 +35,9 @@ class RuntimeMemTracer(): self._cast_buffers_to_cuda_dtype() + def memstats(self): + return self._memstats + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py new file mode 100644 index 000000000..564dee005 --- /dev/null +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -0,0 +1,92 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + +# run gemini use the runtime memory tracer + + +@parameterize('placement_policy', ['auto']) +@parameterize('keep_gather', [False]) +@parameterize('model_name', ['bert', 'albert', 'gpt2']) +@parameterize('use_grad_checkpoint', [False, True]) +def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(use_grad_checkpoint) + + print(f'model_name {model_name}') + runtime_mem_tracer = RuntimeMemTracer(model) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 0: + break + input_ids, label = input_ids.cuda(), label.cuda() + + # mem tracing + if i == 0: + run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) + memstats = runtime_mem_tracer.memstats() + runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list + print('runtime tracer: ', runtime_tracer_non_model_data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + pg = ProcessGroup() + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + if i > 1: + break + input_ids, label = input_ids.cuda(), label.cuda() + + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, model) + + gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + + # print('gemini non model data:', gemini_non_model_data) + + assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ + f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gemini_use_rmt() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gemini_use_rmt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gemini_use_rmt(1)