diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 5b4cb1122..ac8b60033 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -25,7 +25,7 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):
 
     def __init__(self,
-                 adam_optim: Optimizer,
+                 optimizer: Optimizer,
                  sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -37,7 +37,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
-        super().__init__(adam_optim)
+        super().__init__(optimizer)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
@@ -57,7 +57,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 params
         self.master_params: Dict[Parameter, Tensor] = {}
 
-        for group in adam_optim.param_groups:
+        for group in optimizer.param_groups:
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
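
For context, this change only renames the constructor's first parameter from `adam_optim` to `optimizer`, reflecting that any `torch.optim.Optimizer` can back `ShardedOptimizerV2`, not just Adam. Below is a minimal, hypothetical usage sketch of the renamed keyword; the toy model, wrapping, and import paths are assumptions for illustration and omit the shard strategy, ZeRO init context, and distributed launch a real script would need.

```python
# Hypothetical usage sketch of the renamed keyword argument.
# Assumptions: ShardedModelV2 accepts this toy module directly and the
# import paths below match this revision of ColossalAI; real setup
# (shard strategy, ZeRO init context, distributed launch) is omitted.
import torch.nn as nn
from torch.optim import Adam

from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

model = nn.Linear(16, 16).cuda()
sharded_model = ShardedModelV2(model)                        # assumed minimal wrapping

base_optimizer = Adam(sharded_model.parameters(), lr=1e-3)   # any torch Optimizer works

sharded_optim = ShardedOptimizerV2(optimizer=base_optimizer,  # was `adam_optim=...`
                                   sharded_model=sharded_model,
                                   cpu_offload=False)
```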