polish code

pull/394/head
ver217 2022-03-04 13:44:38 +08:00 committed by Frank Lee
parent 36f9a74ab2
commit 3092317b80
1 changed file with 3 additions and 3 deletions

@@ -25,7 +25,7 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
-                 adam_optim: Optimizer,
+                 optimizer: Optimizer,
                  sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -37,7 +37,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
-        super().__init__(adam_optim)
+        super().__init__(optimizer)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
@@ -57,7 +57,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

         # Store fp32 params
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in adam_optim.param_groups:
+        for group in optimizer.param_groups:
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
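
For context, here is a minimal sketch of how the renamed parameter surfaces at a call site. The rename from `adam_optim` to `optimizer` reflects that any `torch.optim.Optimizer` is accepted, not just Adam. The import paths, the `ShardedModelV2` wrapping step, and the tiny model below are assumptions for illustration, not taken from this commit:

import torch.nn as nn
from torch.optim import Adam

# Assumed import paths; the actual module layout in the ColossalAI tree may differ.
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

net = nn.Linear(16, 16)              # placeholder model
sharded_model = ShardedModelV2(net)  # assumed wrapping step

# Any torch.optim.Optimizer should fit the renamed `optimizer` parameter.
base_optim = Adam(sharded_model.parameters(), lr=1e-3)

sharded_optim = ShardedOptimizerV2(optimizer=base_optim,
                                   sharded_model=sharded_model,
                                   cpu_offload=False)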