
[hotfix] add correct device for fake_param (#2796)

Branch: pull/2797/head^2
Author: HELSON, committed by GitHub (2 years ago)
Commit: 56ddc9ca7a
2 changed files:

  colossalai/nn/optimizer/zero_optimizer.py (5 lines changed)
  tests/test_gemini/update/test_zerooptim_state_dict.py (2 lines changed)

colossalai/nn/optimizer/zero_optimizer.py

@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         for group in self.param_groups:
             for fake_param in group['params']:
                 assert fake_param.grad is None
-                fake_param.data = none_tensor
+                fake_param.data = none_tensor.to(fake_param.device)
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 if range_pair[0] >= range_pair[1]:
                     continue
-                fake_param = torch.nn.Parameter(torch.empty([0]))
+                grad_device = self.module.grads_device[param]
+                fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_chunk32[fake_param] = chunk16.paired_chunk
                 self.param_to_range[fake_param] = range_pair
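
This second hunk creates each fake parameter on the device recorded for the real parameter's gradients instead of the CPU default. A hedged sketch of the pattern, using a stand-in dict in place of `self.module.grads_device`:

    import torch

    # Stand-in for Gemini's per-parameter placement table
    # (self.module.grads_device in the real code); purely illustrative.
    grads_device = {"linear.weight": torch.device("cpu")}

    # Allocate the zero-sized fake parameter directly on the device
    # where this parameter's gradients are kept, so optimizer state
    # built from it inherits the right placement.
    grad_device = grads_device["linear.weight"]
    fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
    assert fake_param.device == grad_device
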

tests/test_gemini/update/test_zerooptim_state_dict.py

@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
         for n, m in v.items():
             if isinstance(m, torch.Tensor):
                 o = w[n]
-                if m.device != o.device:
-                    o = o.to(m.device)
                 assert torch.equal(m, o)
             else:
                 assert m == w[n]
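
With fake parameters now created on the correct device, the saved and loaded optimizer states land on matching devices, so the test drops its defensive copy. A tiny sketch of the comparison, with made-up state tensors:

    import torch

    # Made-up optimizer state entries; after the fix both tensors are
    # expected to sit on the same device, so no .to(m.device) copy is
    # needed before comparing.
    m = torch.arange(4, dtype=torch.float32)
    o = torch.arange(4, dtype=torch.float32)
    assert m.device == o.device
    assert torch.equal(m, o)
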
