|
|
@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         for group in self.param_groups:
             for fake_param in group['params']:
                 assert fake_param.grad is None
-                fake_param.data = none_tensor
+                fake_param.data = none_tensor.to(fake_param.device)
 
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()
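This hunk moves the zero-size placeholder onto the fake parameter's own device before rebinding. Rebinding `param.data` makes the parameter adopt the new tensor's device, so a bare CPU `none_tensor` would silently drag a CUDA-resident fake parameter back to the CPU. Below is a minimal standalone sketch of that effect (our illustration, not part of the patch; `none_tensor` here is just a fresh zero-size CPU tensor):

import torch

# Assumed setup: `none_tensor` mimics the zero-size CPU placeholder that
# ZeroOptimizer uses to release a fake parameter's storage.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
none_tensor = torch.empty([0])  # lives on CPU

fake_param = torch.nn.Parameter(torch.empty([0], device=device))

# Old behavior: `fake_param.data = none_tensor` would rebind the parameter
# to the CPU tensor, losing its original device placement.
# Patched behavior: move the placeholder first, keeping the device intact.
fake_param.data = none_tensor.to(fake_param.device)
assert fake_param.device.type == torch.device(device).type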
|
|
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 if range_pair[0] >= range_pair[1]:
                     continue
 
-                fake_param = torch.nn.Parameter(torch.empty([0]))
+                grad_device = self.module.grads_device[param]
+                fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_chunk32[fake_param] = chunk16.paired_chunk
                 self.param_to_range[fake_param] = range_pair
 
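The second hunk creates each fake parameter directly on the device recorded in `self.module.grads_device`, instead of defaulting to the CPU, so anything derived from the fake parameter (notably optimizer state) is allocated where the corresponding gradient actually lives. A minimal sketch under assumed names (`grads_device` as a plain dict and `make_fake_param` are ours for illustration; only the torch calls are real):

import torch

# Hypothetical stand-in for self.module.grads_device: maps each parameter
# to the device its gradient is kept on (e.g. CPU when grads are offloaded).
grads_device = {'linear.weight': torch.device('cpu')}

def make_fake_param(name: str) -> torch.nn.Parameter:
    # torch.empty([0]) allocates no elements; only the device placement
    # matters, since state built from this param inherits its device.
    return torch.nn.Parameter(torch.empty([0], device=grads_device[name]))

fake = make_fake_param('linear.weight')
assert fake.device == grads_device['linear.weight']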
|
|
|
|
|
|
|
|
|
|
|