|
|
@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|
|
|
for group in self.param_groups:
|
|
|
|
for group in self.param_groups:
|
|
|
|
for fake_param in group['params']:
|
|
|
|
for fake_param in group['params']:
|
|
|
|
assert fake_param.grad is None
|
|
|
|
assert fake_param.grad is None
|
|
|
|
fake_param.data = none_tensor
|
|
|
|
fake_param.data = none_tensor.to(fake_param.device)
|
|
|
|
|
|
|
|
|
|
|
|
for chunk16 in self.chunk16_set:
|
|
|
|
for chunk16 in self.chunk16_set:
|
|
|
|
chunk16.optim_update()
|
|
|
|
chunk16.optim_update()
|
|
|
@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|
|
|
if range_pair[0] >= range_pair[1]:
|
|
|
|
if range_pair[0] >= range_pair[1]:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
fake_param = torch.nn.Parameter(torch.empty([0]))
|
|
|
|
grad_device = self.module.grads_device[param]
|
|
|
|
|
|
|
|
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
|
|
|
|
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
|
|
|
|
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
|
|
|
|
self.param_to_range[fake_param] = range_pair
|
|
|
|
self.param_to_range[fake_param] = range_pair
|
|
|
|
|
|
|
|
|
|
|
|