mirror of https://github.com/hpcaitech/ColossalAI
warmup ratio configuration (#1192)
parent dba7e0cfb4
commit a444633d13
@@ -76,7 +76,9 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 reuse_fp16_shard: bool = False):
+                 reuse_fp16_shard: bool = False,
+                 *args,
+                 **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
@@ -119,6 +121,14 @@ class ShardedModelV2(nn.Module):
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
             tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
 
+        if 'warmup_non_model_data_ratio' in kwargs:
+            if tensor_placement_policy != 'auto':
+                self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
+            else:
+                ratio = kwargs['warmup_non_model_data_ratio']
+                self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
+                self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
+
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
         self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
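For context, a minimal usage sketch of what this change enables. The `model` and `shard_strategy` objects below (and the surrounding setup they imply) are assumptions for illustration only, not part of this diff; the point is that `warmup_non_model_data_ratio` is simply forwarded through `**kwargs` and takes effect only when `tensor_placement_policy='auto'`.

    # Sketch, assuming a module and shard strategy have already been prepared
    # in the usual way for ZeRO training with ColossalAI.
    from colossalai.zero.sharded_model import ShardedModelV2

    sharded_model = ShardedModelV2(
        model,                            # an nn.Module prepared for ZeRO (assumed)
        shard_strategy,                   # a shard strategy instance (assumed)
        tensor_placement_policy='auto',   # the ratio only applies to auto placement
        warmup_non_model_data_ratio=0.8,  # forwarded via **kwargs to the placement policy
    )

    # With any policy other than 'auto', the kwarg is ignored and the
    # constructor only logs a warning.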