mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] shared model returns cpu state_dict (#1328)
parent b2475d8c5c
commit 7a05367101
@@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module):
         for p in sharded_params:
             p.data = p.colo_attr.data_payload
         module_to_load = module_to_load or self
-        gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
+        gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
+        gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
         if shard_strategy is not None:
             shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
         for p in sharded_params:
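
The hunk above drops the deepcopy and instead copies every gathered tensor to the CPU before returning, so the caller gets a state_dict that no longer pins GPU memory. Below is a minimal standalone sketch of that pattern in plain PyTorch; the helper name state_dict_to_cpu is illustrative only and not a ColossalAI API.

from typing import Dict

import torch
import torch.nn as nn


def state_dict_to_cpu(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Illustrative helper (not part of ColossalAI): same dict comprehension
    # as the added line above, copying every tensor to CPU so the returned
    # state_dict does not keep GPU memory alive.
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}


if __name__ == '__main__':
    model = nn.Linear(4, 2)
    if torch.cuda.is_available():
        model = model.cuda()
    cpu_sd = state_dict_to_cpu(model.state_dict())
    # Every tensor in the copy now lives on the CPU, regardless of where
    # the module's parameters are placed.
    assert all(v.device.type == 'cpu' for v in cpu_sd.values() if isinstance(v, torch.Tensor))
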
@@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
 
         zero_state_dict = zero_model.state_dict()
         for key, val in model.state_dict().items():
-            assert torch.equal(val, zero_state_dict[key])
+            assert torch.equal(val, zero_state_dict[key].to(val.device))
 
 
 def run_dist(rank, world_size, port):
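
Because the fixed state_dict now holds CPU tensors while the reference model's parameters may sit on the GPU, the test moves the stored copy onto each reference tensor's device before calling torch.equal. A standalone sketch of that comparison pattern, assuming plain PyTorch; check_state_dicts_match is a hypothetical helper, not part of the test suite.

import torch
import torch.nn as nn


def check_state_dicts_match(model: nn.Module, cpu_state_dict: dict) -> None:
    # Hypothetical helper mirroring the updated assertion: move the CPU copy
    # onto each reference tensor's device so torch.equal compares tensors
    # that live on the same device.
    for key, val in model.state_dict().items():
        assert torch.equal(val, cpu_state_dict[key].to(val.device)), f'mismatch at {key}'


if __name__ == '__main__':
    model = nn.Linear(8, 3)
    # CPU copy of the weights, analogous to what the fixed state_dict returns.
    cpu_sd = {k: v.cpu() for k, v in model.state_dict().items()}
    if torch.cuda.is_available():
        model = model.cuda()
    check_state_dicts_match(model, cpu_sd)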