diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 5e06eb646..9940ea5e5 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -76,7 +76,9 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 reuse_fp16_shard: bool = False):
+                 reuse_fp16_shard: bool = False,
+                 *args,
+                 **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
@@ -119,6 +121,14 @@ class ShardedModelV2(nn.Module):
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
             tensor_placement_policy)(mem_stats_collector=self._memstats_collector)

+        if 'warmup_non_model_data_ratio' in kwargs:
+            if tensor_placement_policy != 'auto':
+                self.logger.warning('setting warmup_non_model_data_ratio has no effect unless tensor_placement_policy is "auto"')
+            else:
+                ratio = kwargs['warmup_non_model_data_ratio']
+                self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
+                self.logger.info(f'setting warmup_non_model_data_ratio to {ratio} for auto placement')
+
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
         self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
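
For context, a minimal usage sketch of the new keyword is given below. It assumes the existing ShardedModelV2 API (module and shard strategy arguments, import paths as in this version of ColossalAI, and a module whose parameters carry `colo_attr`, which in practice means it was created under the ZeRO init context); the placeholder model, the strategy choice, and the ratio value are illustrative and not part of this patch.

# Usage sketch (assumptions noted above; not part of the patch).
import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

module = nn.Linear(16, 16)  # placeholder model

sharded_model = ShardedModelV2(
    module=module,
    shard_strategy=TensorShardStrategy(),
    tensor_placement_policy='auto',    # the new kwarg only takes effect with 'auto'
    warmup_non_model_data_ratio=0.8,   # forwarded via **kwargs to the auto placement policy
)

# With any other tensor_placement_policy, passing warmup_non_model_data_ratio
# is accepted but only triggers the warning added in this patch.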