From 70814dc22f1ec2708a3e0ee663ca1b12ee64361f Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 3 Mar 2022 15:50:30 +0800
Subject: [PATCH] fix master params dtype

---
 colossalai/zero/sharded_optim/sharded_adam.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_adam.py b/colossalai/zero/sharded_optim/sharded_adam.py
index 1cb8c4a1d..52b0bf5f7 100644
--- a/colossalai/zero/sharded_optim/sharded_adam.py
+++ b/colossalai/zero/sharded_optim/sharded_adam.py
@@ -26,7 +26,7 @@ class ShardedAdam(ColossalaiOptimizer):
 
     def __init__(self,
                  adam_optim: Optimizer,
-                 sharded_model: nn.Module,
+                 sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -61,9 +61,11 @@ class ShardedAdam(ColossalaiOptimizer):
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                    self.master_params[p] = p.ca_attr.payload(self.device)
                 else:
-                    self.master_params[p] = p.data.to(torch.float)
+                    self.master_params[p] = p.data.to(device=self.device)
+                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
+                    self.master_params[p] = self.master_params[p].to(torch.float)
 
     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -85,8 +87,9 @@ class ShardedAdam(ColossalaiOptimizer):
         # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
-                # TODO: update payload
-                p.data = None
+                if hasattr(p, 'ca_attr'):
+                    # TODO: update payload
+                    p.data = None
         return ret
 
     def backward(self, loss: Tensor) -> None:
@@ -129,10 +132,7 @@ class ShardedAdam(ColossalaiOptimizer):
         # all-reduce over model parallel group
         dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)
 
-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
+        return self._found_overflow.item() > 0
 
     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
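
For context, a minimal standalone sketch of the dtype handling that the patched master-parameter initialization performs (the helper name and example tensors below are illustrative, not part of the patch): the master copy is first moved to the target device as-is, and only floating-point tensors that are not already fp32 are cast to torch.float, so non-floating-point parameters keep their original dtype.

import torch

def make_master_param(p: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Mirror the patched logic: move to the target device first, then cast to
    # fp32 only when the tensor is floating point and not already fp32.
    master = p.data.to(device=device)
    if torch.is_floating_point(master) and master.dtype != torch.float:
        master = master.to(torch.float)
    return master

# fp16 parameters get fp32 master copies; integer tensors are left unchanged.
fp16_param = torch.randn(4, dtype=torch.float16)
int_tensor = torch.arange(4, dtype=torch.int64)
print(make_master_param(fp16_param, torch.device('cpu')).dtype)  # torch.float32
print(make_master_param(int_tensor, torch.device('cpu')).dtype)  # torch.int64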