fix master params dtype

pull/394/head
ver217 2022-03-03 15:50:30 +08:00 committed by Frank Lee
parent 795210dd99
commit 70814dc22f
1 changed file with 9 additions and 9 deletions

@@ -26,7 +26,7 @@ class ShardedAdam(ColossalaiOptimizer):
     def __init__(self,
                  adam_optim: Optimizer,
-                 sharded_model: nn.Module,
+                 sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -61,9 +61,11 @@ class ShardedAdam(ColossalaiOptimizer):
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                    self.master_params[p] = p.ca_attr.payload(self.device)
                 else:
-                    self.master_params[p] = p.data.to(torch.float)
+                    self.master_params[p] = p.data.to(device=self.device)
+                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
+                    self.master_params[p] = self.master_params[p].to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
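The core of the commit is the dtype handling in the hunk above: the sharded payload keeps whatever dtype and device it already has, and a floating-point master copy is promoted to fp32 only when it is not fp32 already. A minimal standalone sketch of that casting rule, assuming plain PyTorch tensors (the helper name below is illustrative, not part of the patch):

import torch

def to_master_dtype(t: torch.Tensor) -> torch.Tensor:
    # Mirror the new check: only floating-point tensors that are not already
    # fp32 (e.g. fp16 payloads) are promoted; integer/bool tensors are left alone.
    if torch.is_floating_point(t) and t.dtype != torch.float:
        return t.to(torch.float)
    return t

print(to_master_dtype(torch.ones(2, dtype=torch.float16)).dtype)  # torch.float32
print(to_master_dtype(torch.ones(2, dtype=torch.int64)).dtype)    # torch.int64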
@@ -85,8 +87,9 @@ class ShardedAdam(ColossalaiOptimizer):
         # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
-                # TODO: update payload
-                p.data = None
+                if hasattr(p, 'ca_attr'):
+                    # TODO: update payload
+                    p.data = None
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -129,10 +132,7 @@ class ShardedAdam(ColossalaiOptimizer):
         # all-reduce over model parallel group
         dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)

-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
+        return self._found_overflow.item() > 0

     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
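For context on the last hunk, the overflow check now returns the boolean directly after the max-reduction over the model parallel group. A minimal sketch of the same pattern outside the class, assuming torch.distributed has already been initialized (e.g. with the gloo backend); the function name and arguments are illustrative, not taken from the patch:

import torch
import torch.distributed as dist

def found_overflow(local_overflow: bool, process_group=None) -> bool:
    # Encode the local flag as a tensor, take the max over all ranks,
    # and fold the comparison into the return, as the simplified code does.
    flag = torch.tensor([1.0 if local_overflow else 0.0])
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=process_group)
    return flag.item() > 0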