diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 14b670a88..19f9c343b 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -29,7 +29,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  sharded_model: ShardedModelV2,
                  optimizer_class: Type[Optimizer],
-                 shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -42,20 +41,43 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  mp_process_group: Optional[ProcessGroup] = None,
                  **defaults: Any) -> None:
         """
-        :param sharded_model: A sharded model initialized by class ShardedModelV2
+        :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
+            shard strategy provided by the sharded model to shard param fp32 tensors.
         :type sharded_model: sharded_model
 
-        :param optimizer_class: A type of Optimizer
+        :param optimizer_class: A class type of Optimizer
         :type optimizer_class: Type[Optimizer]
 
-        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
-        :type shard_strategy: BaseShardStrategy
-
         :param cpu_offload: is offloading the optimizer states to CPU.
         :type cpu_offload: bool
 
-        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
-        :type shard_strategy: BaseShardStrategy
+        :param initial_scale: initial scale used by DynamicGradScaler
+        :type initial_scale: float
+
+        :param min_scale: min scale used by DynamicGradScaler
+        :type min_scale: float
+
+        :param growth_factor: growth_factor used by DynamicGradScaler
+        :type growth_factor: float
+
+        :param backoff_factor: backoff_factor used by DynamicGradScaler
+        :type backoff_factor: float
+
+        :param growth_interval: growth_interval used by DynamicGradScaler
+        :type growth_interval: float
+
+        :param hysteresis: hysteresis used by DynamicGradScaler
+        :type hysteresis: float
+
+        :param max_scale: max_scale used by DynamicGradScaler
+        :type max_scale: float
+
+        :param dp_process_group: data parallel process group
+        :type dp_process_group: Optional[ProcessGroup]
+
+        :param mp_process_group: model parallel process group
+        :type mp_process_group: Optional[ProcessGroup]
+
         :**defaults: any trailing arguments, which are forwarded to the local optimizer.
         :type defaults: dict()
         """
@@ -67,7 +89,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
         super().__init__(self.optimizer)
 
-        self.shard_strategy = shard_strategy
+        self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
             raise RuntimeError(
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 622df5693..aeaa7afaf 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -52,7 +52,6 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     optim = optimizer_class(model.parameters(), lr=lr)
     sharded_optim = ShardedOptimizerV2(zero_model,
                                        optimizer_class,
-                                       shard_strategy,
                                        cpu_offload=cpu_offload,
                                        initial_scale=2**5,
                                        lr=lr)
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
index 942b46723..424ca3a65 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
@@ -59,12 +59,7 @@ def run_dist(rank, world_size, port, shard_strategy):
     if dist.get_world_size() > 1:
         model = DDP(model)
     optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(zero_model,
-                                       CPUAdam,
-                                       shard_strategy,
-                                       initial_scale=2**5,
-                                       cpu_offload=True,
-                                       lr=1e-3)
+    sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 2:
             break
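
For reviewers, a minimal before/after sketch of the call-site change. Only the `ShardedOptimizerV2` call shapes and the `sharded_model.shard_strategy` attribute come from this diff; the import paths, the `TensorShardStrategy` choice, and the `ShardedModelV2` constructor usage below are assumptions for illustration.

```python
# Hypothetical usage sketch: import paths and ShardedModelV2's constructor are assumed,
# only the ShardedOptimizerV2 call shapes are taken from this diff.
import torch
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed import path
from colossalai.zero.sharded_model import ShardedModelV2      # assumed import path
from colossalai.zero.sharded_optim import ShardedOptimizerV2  # assumed import path

shard_strategy = TensorShardStrategy()
model = torch.nn.Linear(8, 8).cuda()
# Assumed construction: the sharded model owns the shard strategy.
zero_model = ShardedModelV2(model, shard_strategy)

# Before this change, the strategy had to be passed a second time:
#   ShardedOptimizerV2(zero_model, torch.optim.Adam, shard_strategy, initial_scale=2**5, lr=1e-3)
# After this change, the optimizer reuses zero_model.shard_strategy internally:
sharded_optim = ShardedOptimizerV2(zero_model, torch.optim.Adam, initial_scale=2**5, lr=1e-3)
```

Dropping the duplicated argument removes the chance of handing the optimizer a strategy that disagrees with the one the model was actually sharded with, since `ShardedOptimizerV2` now reads `sharded_model.shard_strategy` directly.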