[hotfix] fix auto policy of test_sharded_optim_v2 (#2157)

pull/2144/head^2
Jiarui Fang 2022-12-20 23:03:18 +08:00 committed by GitHub
parent 16335cb537
commit b87496a66b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 3 additions and 13 deletions

View File

@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
@property
def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda

View File

@ -107,14 +107,6 @@ class MemStats(object):
else:
raise TypeError
def max_overall_cuda(self, device_type: str) -> float:
if device_type == 'cuda':
return max(self._overall_cuda_list)
elif device_type == 'cpu':
return max(self._overall_cpu_list)
else:
raise TypeError
def clear(self):
self._model_data_cuda_list = []
self._overall_cuda_list = []

View File

@ -79,9 +79,7 @@ class MemStatsCollector:
if self._start_flag and not self.use_outside_memstats:
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
self._memstats.record_max_cuda_model_data(cuda_mem)
def sample_overall_data(self) -> None:
"""

View File

@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
zero_model = ShardedModelV2(
zero_model,
shard_strategy,
tensor_placement_policy='cpu' if cpu_offload else 'cuda',
tensor_placement_policy='cpu' if cpu_offload else 'auto',
reuse_fp16_shard=use_cpuadam,
)