[zero] refactor ShardedOptimV2 init method (#416)

pull/417/head
Jiarui Fang 2022-03-15 10:45:55 +08:00 committed by GitHub
parent e79ea44247
commit 23ba3fc450
3 changed files with 32 additions and 16 deletions


@@ -29,7 +29,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def __init__(self,
sharded_model: ShardedModelV2,
optimizer_class: Type[Optimizer],
shard_strategy: BaseShardStrategy,
cpu_offload: bool = False,
initial_scale: float = 2**32,
min_scale: float = 1,
@@ -42,20 +41,43 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
mp_process_group: Optional[ProcessGroup] = None,
**defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by the sharded model to shard the param fp32 tensors.
:type sharded_model: ShardedModelV2
:param optimizer_class: A type of Optimizer
:param optimizer_class: A class type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: whether to offload the optimizer states to CPU.
:type cpu_offload: bool
:param initial_scale: initial scale used by DynamicGradScaler
:type initial_scale: float
:param min_scale: min scale used by DynamicGradScaler
:type min_scale: float
:param growth_factor: growth_factor used by DynamicGradScaler
:type growth_factor: float
:param backoff_factor: backoff_factor used by DynamicGradScaler
:type backoff_factor: float
:param growth_interval: growth_interval used by DynamicGradScaler
:type growth_interval: float
:param hysteresis: hysteresis used by DynamicGradScaler
:type hysteresis: float
:param max_scale: max_scale used by DynamicGradScaler
:type max_scale: float
:param dp_process_group: data parallel process group
:type dp_process_group: Optional[ProcessGroup]
:param mp_process_group: model parallel process group
:type mp_process_group: Optional[ProcessGroup]
:param defaults: any trailing keyword arguments, which are forwarded to the local optimizer.
:type defaults: dict
"""
@@ -67,7 +89,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
super().__init__(self.optimizer)
self.shard_strategy = shard_strategy
self.shard_strategy = sharded_model.shard_strategy
self.model: ShardedModelV2 = sharded_model
if cpu_offload and not sharded_model.cpu_offload:
raise RuntimeError(
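
Taken together, the two hunks above change the constructor contract: ShardedOptimizerV2 no longer takes a shard_strategy argument and instead reuses the strategy already attached to the sharded model (self.shard_strategy = sharded_model.shard_strategy). A minimal usage sketch of the new call follows; the strategy class, import paths, and surrounding setup are assumptions for illustration, not part of this diff.

import torch
from colossalai.zero.shard_utils import TensorShardStrategy    # assumed import path
from colossalai.zero.sharded_model import ShardedModelV2       # assumed import path
from colossalai.zero.sharded_optim import ShardedOptimizerV2   # assumed import path

# The strategy is handed to the model once; the optimizer now picks it up from
# sharded_model.shard_strategy rather than receiving a second copy.
shard_strategy = TensorShardStrategy()
zero_model = ShardedModelV2(torch.nn.Linear(16, 16).cuda(), shard_strategy)
sharded_optim = ShardedOptimizerV2(zero_model,
                                   torch.optim.Adam,     # optimizer_class
                                   cpu_offload=False,
                                   initial_scale=2**5,   # DynamicGradScaler setting
                                   lr=1e-3)              # forwarded to Adam via **defaults

The two test-file diffs below migrate the in-repo call sites to the same pattern.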


@@ -52,7 +52,6 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
optim = optimizer_class(model.parameters(), lr=lr)
sharded_optim = ShardedOptimizerV2(zero_model,
optimizer_class,
shard_strategy,
cpu_offload=cpu_offload,
initial_scale=2**5,
lr=lr)


@@ -59,12 +59,7 @@ def run_dist(rank, world_size, port, shard_strategy):
if dist.get_world_size() > 1:
model = DDP(model)
optim = Adam(model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model,
CPUAdam,
shard_strategy,
initial_scale=2**5,
cpu_offload=True,
lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)
for i, (data, label) in enumerate(train_dataloader):
if i > 2:
break