warmup ratio configuration (#1192)

pull/1194/head
Jiarui Fang 2022-06-30 15:23:50 +08:00 committed by GitHub
parent dba7e0cfb4
commit a444633d13
1 changed file with 11 additions and 1 deletion

@@ -76,7 +76,9 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 reuse_fp16_shard: bool = False):
+                 reuse_fp16_shard: bool = False,
+                 *args,
+                 **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
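
The widened signature forwards any extra keyword arguments into `kwargs`, so new tuning knobs such as `warmup_non_model_data_ratio` can be accepted without another signature change. A minimal, self-contained sketch of that forwarding pattern (hypothetical `Wrapper` class, not the real initializer; the 0.8 default is an assumption):

# Hypothetical illustration of the keyword-forwarding pattern above.
class Wrapper:
    def __init__(self, placement: str = 'cuda', *args, **kwargs):
        # Optional knobs travel through **kwargs instead of new named params.
        self.ratio = kwargs.get('warmup_non_model_data_ratio', 0.8)  # assumed default

w = Wrapper(placement='auto', warmup_non_model_data_ratio=0.6)
assert w.ratio == 0.6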
@@ -119,6 +121,14 @@ class ShardedModelV2(nn.Module):
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
             tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
+        if 'warmup_non_model_data_ratio' in kwargs:
+            if tensor_placement_policy != 'auto':
+                self.logger.warning('setting warmup_non_model_data_ratio is useless if not using auto placement')
+            else:
+                ratio = kwargs['warmup_non_model_data_ratio']
+                self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
+                self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
         self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
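
For reference, a hedged usage sketch: assuming the surrounding ColossalAI ZeRO setup (shard strategy and the other required constructor arguments are elided here), the new knob is passed at construction time and only takes effect under the 'auto' placement policy:

# Sketch only; other required ShardedModelV2 arguments are elided.
model = ShardedModelV2(
    module,                             # an nn.Module whose params carry colo_attr
    tensor_placement_policy='auto',     # required for the ratio to be applied
    warmup_non_model_data_ratio=0.8,    # picked up from **kwargs by the branch above
)
# With any other placement policy the kwarg is accepted but only triggers a warning.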