From b87496a66b0f296411434ad3a6524e335a082725 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 20 Dec 2022 23:03:18 +0800 Subject: [PATCH] [hotfix] fix auto policy of test_sharded_optim_v2 (#2157) --- .../gemini/memory_tracer/chunk_memstats_collector.py | 2 +- colossalai/gemini/memory_tracer/memory_stats.py | 8 -------- colossalai/gemini/memory_tracer/memstats_collector.py | 4 +--- tests/test_zero/test_sharded_optim_v2.py | 2 +- 4 files changed, 3 insertions(+), 13 deletions(-) diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py index 44c11302e..1a5b6bf52 100644 --- a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py @@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector): @property def cuda_margin_mem(self) -> float: - return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda') + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py index 0f8390e02..84fa00fb9 100644 --- a/colossalai/gemini/memory_tracer/memory_stats.py +++ b/colossalai/gemini/memory_tracer/memory_stats.py @@ -107,14 +107,6 @@ class MemStats(object): else: raise TypeError - def max_overall_cuda(self, device_type: str) -> float: - if device_type == 'cuda': - return max(self._overall_cuda_list) - elif device_type == 'cpu': - return max(self._overall_cpu_list) - else: - raise TypeError - def clear(self): self._model_data_cuda_list = [] self._overall_cuda_list = [] diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index d521fe212..233fefcad 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -79,9 +79,7 @@ class MemStatsCollector: if self._start_flag and not self.use_outside_memstats: # The following code work for ZeroInitContext, which is deprecated in v0.1.12 cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] - cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] - self._memstats.append_model_data('cuda', cuda_mem) - self._memstats.append_model_data('cpu', cpu_mem) + self._memstats.record_max_cuda_model_data(cuda_mem) def sample_overall_data(self) -> None: """ diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py index 221915167..8fe7eb639 100644 --- a/tests/test_zero/test_sharded_optim_v2.py +++ b/tests/test_zero/test_sharded_optim_v2.py @@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g zero_model = ShardedModelV2( zero_model, shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', + tensor_placement_policy='cpu' if cpu_offload else 'auto', reuse_fp16_shard=use_cpuadam, )