From 1f9920582784994b6b3dd69f8440370641b830a5 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 6 Dec 2022 12:53:58 +0800
Subject: [PATCH] [Gemini] remove static tracer (#2083)

---
 colossalai/gemini/gemini_mgr.py           | 24 +++----------------
 .../memory_tracer/runtime_mem_tracer.py   | 10 ++++++++
 colossalai/nn/parallel/gemini_parallel.py |  2 +-
 tests/test_tensor/model/test_model.py     |  2 +-
 4 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 781ffe771..317c4f15c 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -26,27 +26,13 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """
 
-    def __init__(self,
-                 placement_policy: str,
-                 chunk_manager: ChunkManager,
-                 module: Optional[torch.nn.Module] = None,
-                 use_static_memstats: bool = False) -> None:
+    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
-        self.use_static_memstats = use_static_memstats
-        if policy_cls.need_mem_stats:
-            if use_static_memstats:
-                assert module is not None
-                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
-            else:
-                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
-        else:
-            self._mem_stats_collector = None
-
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
 
@@ -60,11 +46,7 @@ class GeminiManager:
 
     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
-            if self.use_static_memstats:
-                self._mem_stats_collector.init_mem_stats(*args)
-                self._warmup = False
-            else:
-                self._mem_stats_collector.start_collection()
+            self._mem_stats_collector.start_collection()
 
     def post_iter(self):
         """This function must be called when each iteration finishes
diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
index 277371a36..3b16686c7 100644
--- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
@@ -9,6 +9,16 @@ __all__ = ['RuntimeMemTracer']
 
 
 class RuntimeMemTracer():
+    """RuntimeMemTracer for training a module whose parameters are ColoParameter.
+
+    It traces non-model memory usage during the fwd+bwd process.
+    The statistics are collected by feeding inputs with the same shapes as in real training
+    and running a single fwd+bwd pass.
+
+    NOTE:
+    1. The premise for using this tracer is that the target DNN executes the same operations at each iteration.
+    2. Module buffers are viewed as non-model data.
+ """ def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half): super().__init__() diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py index 9f13cece2..bf11631f9 100644 --- a/colossalai/nn/parallel/gemini_parallel.py +++ b/colossalai/nn/parallel/gemini_parallel.py @@ -50,5 +50,5 @@ class GeminiDDP(ZeroDDP): hidden_dim=hidden_dim, search_range_mb=search_range_mb, min_chunk_size_mb=min_chunk_size_mb) - gemini_manager = GeminiManager(placement_policy, chunk_manager, module) + gemini_manager = GeminiManager(placement_policy, chunk_manager) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py index 361fef8aa..3f53b94e0 100644 --- a/tests/test_tensor/model/test_model.py +++ b/tests/test_tensor/model/test_model.py @@ -117,7 +117,7 @@ def run_1d_hybrid_tp(model_name): else: output_torch = model_torch(data, label) loss_torch = output_torch - assert torch.allclose(loss, loss_torch, rtol=1e-2) + assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed" torch.distributed.barrier() loss.backward()