diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 402e28ce8..712daed06 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         for group in self.param_groups:
             for fake_param in group['params']:
                 assert fake_param.grad is None
-                fake_param.data = none_tensor
+                fake_param.data = none_tensor.to(fake_param.device)
 
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 if range_pair[0] >= range_pair[1]:
                     continue
 
-                fake_param = torch.nn.Parameter(torch.empty([0]))
+                grad_device = self.module.grads_device[param]
+                fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_chunk32[fake_param] = chunk16.paired_chunk
                 self.param_to_range[fake_param] = range_pair
 
diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py
index dc3dda9d6..fd13af6b2 100644
--- a/tests/test_gemini/update/test_zerooptim_state_dict.py
+++ b/tests/test_gemini/update/test_zerooptim_state_dict.py
@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
         for n, m in v.items():
             if isinstance(m, torch.Tensor):
                 o = w[n]
-                if m.device != o.device:
-                    o = o.to(m.device)
                 assert torch.equal(m, o)
             else:
                 assert m == w[n]
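
Note: a minimal standalone sketch (not the ColossalAI code; names here are
illustrative) of the behavior this patch relies on. PyTorch optimizers lazily
allocate their state tensors on the parameter's device, so creating each
zero-size fake parameter directly on self.module.grads_device[param] keeps the
saved optimizer state on the expected device, which is why the test can drop
its cross-device workaround when comparing state dicts.

    import torch

    # Sketch only: 'device' stands in for the grads device the patch reads
    # from self.module.grads_device[param].
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    param = torch.nn.Parameter(torch.randn(4, device=device))
    optim = torch.optim.Adam([param])

    param.grad = torch.zeros_like(param)  # gradient lives on the parameter's device
    optim.step()                          # Adam lazily allocates exp_avg/exp_avg_sq here

    # The optimizer state follows the parameter's device, so a fake parameter
    # created on the grads device yields a state dict with consistent devices.
    assert optim.state[param]['exp_avg'].device == param.device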