[hotfix] fix auto policy of test_sharded_optim_v2 (#2157)

pull/2144/head^2
Jiarui Fang 2022-12-20 23:03:18 +08:00 committed by GitHub
parent 16335cb537
commit b87496a66b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 3 additions and 13 deletions

View File

@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
@property @property
def cuda_margin_mem(self) -> float: def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda') return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda

View File

@ -107,14 +107,6 @@ class MemStats(object):
else: else:
raise TypeError raise TypeError
def max_overall_cuda(self, device_type: str) -> float:
if device_type == 'cuda':
return max(self._overall_cuda_list)
elif device_type == 'cpu':
return max(self._overall_cpu_list)
else:
raise TypeError
def clear(self): def clear(self):
self._model_data_cuda_list = [] self._model_data_cuda_list = []
self._overall_cuda_list = [] self._overall_cuda_list = []

View File

@ -79,9 +79,7 @@ class MemStatsCollector:
if self._start_flag and not self.use_outside_memstats: if self._start_flag and not self.use_outside_memstats:
# The following code work for ZeroInitContext, which is deprecated in v0.1.12 # The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] self._memstats.record_max_cuda_model_data(cuda_mem)
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
def sample_overall_data(self) -> None: def sample_overall_data(self) -> None:
""" """

View File

@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
zero_model = ShardedModelV2( zero_model = ShardedModelV2(
zero_model, zero_model,
shard_strategy, shard_strategy,
tensor_placement_policy='cpu' if cpu_offload else 'cuda', tensor_placement_policy='cpu' if cpu_offload else 'auto',
reuse_fp16_shard=use_cpuadam, reuse_fp16_shard=use_cpuadam,
) )