mirror of https://github.com/hpcaitech/ColossalAI
warmup ratio configuration (#1192)
parent dba7e0cfb4
commit a444633d13
@@ -76,7 +76,9 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 reuse_fp16_shard: bool = False):
+                 reuse_fp16_shard: bool = False,
+                 *args,
+                 **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
@@ -119,6 +121,14 @@ class ShardedModelV2(nn.Module):
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
             tensor_placement_policy)(mem_stats_collector=self._memstats_collector)

+        if 'warmup_non_model_data_ratio' in kwargs:
+            if tensor_placement_policy != 'auto':
+                self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
+            else:
+                ratio = kwargs['warmup_non_model_data_ratio']
+                self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
+                self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
+
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
         self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
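For context, below is a minimal, self-contained sketch of the configuration pattern this commit introduces: an extra keyword (warmup_non_model_data_ratio) is accepted through **kwargs and only takes effect when the tensor placement policy is 'auto'. The class and policy names in the sketch are illustrative stand-ins, not the ColossalAI API; only the control flow mirrors the diff above.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('sharded_model_sketch')


class AutoPlacementPolicySketch:
    """Stand-in for an 'auto' placement policy exposing a warmup ratio knob (illustrative only)."""

    def __init__(self) -> None:
        # Fraction of device memory assumed to be non-model data during the warmup phase.
        self._warmup_non_model_data_ratio = 0.8


class ShardedModelSketch:
    """Stand-in for a sharded model wrapper; mirrors only the kwargs handling shown in the diff."""

    def __init__(self, tensor_placement_policy: str = 'cuda', **kwargs) -> None:
        self._tensor_placement_policy = AutoPlacementPolicySketch()

        # Same control flow as the committed code: the ratio is honored only for 'auto'.
        if 'warmup_non_model_data_ratio' in kwargs:
            if tensor_placement_policy != 'auto':
                logger.warning('warmup_non_model_data_ratio is ignored unless the auto placement policy is used')
            else:
                ratio = kwargs['warmup_non_model_data_ratio']
                self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
                logger.info(f'warmup_non_model_data_ratio set to {ratio} for auto placement')


# Usage: the ratio takes effect only with the 'auto' policy; otherwise a warning is emitted.
ShardedModelSketch(tensor_placement_policy='auto', warmup_non_model_data_ratio=0.8)
ShardedModelSketch(tensor_placement_policy='cuda', warmup_non_model_data_ratio=0.8)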