mirror of https://github.com/hpcaitech/ColossalAI
polish code
parent
36f9a74ab2
commit
3092317b80
|
@ -25,7 +25,7 @@ class OptimState(Enum):
|
||||||
class ShardedOptimizerV2(ColossalaiOptimizer):
|
class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
adam_optim: Optimizer,
|
optimizer: Optimizer,
|
||||||
sharded_model: Union[nn.Module, ShardedModelV2],
|
sharded_model: Union[nn.Module, ShardedModelV2],
|
||||||
cpu_offload: bool = False,
|
cpu_offload: bool = False,
|
||||||
initial_scale: float = 2**32,
|
initial_scale: float = 2**32,
|
||||||
|
@ -37,7 +37,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
max_scale: int = 2**32,
|
max_scale: int = 2**32,
|
||||||
dp_process_group: Optional[ProcessGroup] = None,
|
dp_process_group: Optional[ProcessGroup] = None,
|
||||||
mp_process_group: Optional[ProcessGroup] = None) -> None:
|
mp_process_group: Optional[ProcessGroup] = None) -> None:
|
||||||
super().__init__(adam_optim)
|
super().__init__(optimizer)
|
||||||
self.model: Union[nn.Module, ShardedModelV2] = sharded_model
|
self.model: Union[nn.Module, ShardedModelV2] = sharded_model
|
||||||
self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
|
self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
|
||||||
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
|
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
|
||||||
|
@ -57,7 +57,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
# Store fp32 params
|
# Store fp32 params
|
||||||
self.master_params: Dict[Parameter, Tensor] = {}
|
self.master_params: Dict[Parameter, Tensor] = {}
|
||||||
|
|
||||||
for group in adam_optim.param_groups:
|
for group in optimizer.param_groups:
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
if hasattr(p, 'ca_attr'):
|
if hasattr(p, 'ca_attr'):
|
||||||
assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
|
assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
|
||||||
|
|
Loading…
Reference in New Issue