mirror of https://github.com/hpcaitech/ColossalAI
fix master params dtype
parent 795210dd99
commit 70814dc22f
@@ -26,7 +26,7 @@ class ShardedAdam(ColossalaiOptimizer):

     def __init__(self,
                  adam_optim: Optimizer,
-                 sharded_model: nn.Module,
+                 sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -61,9 +61,11 @@ class ShardedAdam(ColossalaiOptimizer):
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                    self.master_params[p] = p.ca_attr.payload(self.device)
                 else:
-                    self.master_params[p] = p.data.to(torch.float)
+                    self.master_params[p] = p.data.to(device=self.device)
+                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
+                    self.master_params[p] = self.master_params[p].to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
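The hunk above carries the actual dtype fix: a master copy is upcast to fp32 only when it is a floating-point tensor, so non-float parameters keep their original dtype. Below is a minimal standalone sketch of that pattern; the helper name and the use of plain tensors in place of the real ca_attr sharded payloads are assumptions for illustration, not part of the commit.

import torch

def make_master_param(p: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Hypothetical helper mirroring the fixed logic: move the tensor to the
    # optimizer's device first, then upcast to fp32 only if it is floating point.
    master = p.data.to(device=device)
    if torch.is_floating_point(master) and master.dtype != torch.float:
        master = master.to(torch.float)
    return master

# fp16 weights get fp32 master copies; integer tensors are left untouched.
half = torch.ones(4, dtype=torch.half)
ints = torch.ones(4, dtype=torch.long)
assert make_master_param(half, torch.device('cpu')).dtype == torch.float
assert make_master_param(ints, torch.device('cpu')).dtype == torch.long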
@@ -85,8 +87,9 @@ class ShardedAdam(ColossalaiOptimizer):
         # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
-                # TODO: update payload
-                p.data = None
+                if hasattr(p, 'ca_attr'):
+                    # TODO: update payload
+                    p.data = None
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -129,10 +132,7 @@ class ShardedAdam(ColossalaiOptimizer):
         # all-reduce over model parallel group
         dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)

-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
+        return self._found_overflow.item() > 0

     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
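The last hunk collapses the overflow check into a single boolean return after the MAX all-reduce, so every rank in the model parallel group agrees on whether the step should be skipped. A small sketch of that pattern as a free function; the function name and the assumption that found_overflow is a one-element tensor set to 1.0 locally on overflow are illustrative, not taken from the commit.

import torch
import torch.distributed as dist

def any_rank_overflowed(found_overflow: torch.Tensor, process_group=None) -> bool:
    # A MAX all-reduce turns the flag into 1.0 on every rank if any rank saw
    # inf/nan gradients; comparing against zero yields a plain Python bool.
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=process_group)
    return found_overflow.item() > 0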