From 6e553748a7e068c6b1232c5d31d3fc7b8b4b6dee Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 14 Apr 2022 21:03:59 +0800
Subject: [PATCH] polish sharded optim docstr and warning (#770)

---
 colossalai/zero/sharded_optim/sharded_optim_v2.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 680f86962..c4fbf1b7c 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -50,7 +50,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     You must use ``ShardedOptimizerV2`` with ``ShardedModelV2``.
 
     Note:
-        Make sure you enable ``use_memory_tracer`` in ``ShardedModelV2``,
+        Make sure you set ``tensor_placement_policy`` in ``ShardedModelV2`` to `"auto"`,
         if you set ``gpu_margin_mem_ratio > 0``.
 
     Args:
@@ -59,7 +59,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         optimizer (Optimizer): An Optimizer instance.
         gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
             which will be used when using hybrid CPU optimizer.
-            Make sure `reuse_fp16_shard` is enabled in `ShardedModelV2`, if `gpu_margin_mem_ratio` > `0.0`.
             This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
             Defaults to 0.0.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
@@ -119,8 +118,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self._register_master_weight()
 
-        if self.gpu_margin_mem_ratio != 0.0 and isinstance(sharded_model._tensor_placement_policy,
-                                                           AutoTensorPlacementPolicy):
+        if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
+                                                               AutoTensorPlacementPolicy):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"')
         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
                            ranks=[0])
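
For context, a minimal sketch of how the polished docstring expects these pieces to fit together. The import paths, TensorShardStrategy, and the shard_strategy keyword below are assumptions for illustration and are not verified against this revision; gpu_margin_mem_ratio, initial_scale, and tensor_placement_policy are the names appearing in the patched docstring.

    # Minimal sketch, assuming the import paths and keywords noted in the comments.
    import torch
    from colossalai.zero.shard_utils import TensorShardStrategy      # assumed import path
    from colossalai.zero.sharded_model import ShardedModelV2         # assumed import path
    from colossalai.zero.sharded_optim import ShardedOptimizerV2     # assumed import path

    model = torch.nn.Linear(1024, 1024)

    # gpu_margin_mem_ratio > 0 only takes effect with the "auto" placement policy;
    # with any other policy, the warning touched by this patch is emitted instead.
    sharded_model = ShardedModelV2(model,
                                   shard_strategy=TensorShardStrategy(),   # assumed keyword
                                   tensor_placement_policy='auto')

    optimizer = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)
    sharded_optim = ShardedOptimizerV2(sharded_model,
                                       optimizer,
                                       gpu_margin_mem_ratio=0.8,   # > 0.0, so "auto" placement is required
                                       initial_scale=2**5)

With any non-"auto" placement policy, the corrected condition now warns that gpu_margin_mem_ratio is meaningless, instead of warning in the "auto" case as the old (inverted) check did.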