mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] add correct device for fake_param (#2796)
parent
a619a190df
commit
56ddc9ca7a
|
@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
assert fake_param.grad is None
|
||||
- fake_param.data = none_tensor
|
||||
+ fake_param.data = none_tensor.to(fake_param.device)
|
||||
|
||||
for chunk16 in self.chunk16_set:
|
||||
chunk16.optim_update()
|
||||
|
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
if range_pair[0] >= range_pair[1]:
|
||||
continue
|
||||
|
||||
- fake_param = torch.nn.Parameter(torch.empty([0]))
|
||||
+ grad_device = self.module.grads_device[param]
|
||||
+ fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
|
||||
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
|
||||
self.param_to_range[fake_param] = range_pair
|
||||
|
||||
|
|
|
@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
|
|||
for n, m in v.items():
|
||||
if isinstance(m, torch.Tensor):
|
||||
o = w[n]
|
||||
- if m.device != o.device:
|
||||
- o = o.to(m.device)
|
||||
assert torch.equal(m, o)
|
||||
else:
|
||||
assert m == w[n]
|
||||
|
|
Loading…
Reference in New Issue