mirror of https://github.com/hpcaitech/ColossalAI
polish code
parent 36f9a74ab2
commit 3092317b80
@@ -25,7 +25,7 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
-                 adam_optim: Optimizer,
+                 optimizer: Optimizer,
                  sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -37,7 +37,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
-        super().__init__(adam_optim)
+        super().__init__(optimizer)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
@@ -57,7 +57,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 params
         self.master_params: Dict[Parameter, Tensor] = {}

-        for group in adam_optim.param_groups:
+        for group in optimizer.param_groups:
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
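The commit only renames the constructor argument from adam_optim to optimizer (and its two internal uses), which makes clear that any torch.optim.Optimizer can be wrapped, not just Adam. Below is a minimal usage sketch under that reading; the import path and the plain nn.Linear stand-in for a ShardedModelV2-wrapped model are illustrative assumptions, not taken from this commit, which shows only part of the constructor.

# Minimal sketch: constructing ShardedOptimizerV2 after the rename.
# Assumptions: the import path below and the nn.Linear stand-in model are
# illustrative only; a real setup would pass a ShardedModelV2 instance.
import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.zero.sharded_optim import ShardedOptimizerV2  # assumed import path

model = nn.Linear(16, 16).cuda()                # stand-in; normally a ShardedModelV2
base_optim = Adam(model.parameters(), lr=1e-3)  # any torch.optim.Optimizer

sharded_optim = ShardedOptimizerV2(
    optimizer=base_optim,       # renamed from adam_optim in this commit
    sharded_model=model,
    cpu_offload=False,          # keep master fp32 params on the GPU
    initial_scale=2**32,        # loss-scale argument visible in the diffed signature
)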