[hotfix] shared model returns cpu state_dict (#1328)

pull/1329/head
ver217 2 years ago committed by GitHub
parent b2475d8c5c
commit 7a05367101
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module):
for p in sharded_params:
p.data = p.colo_attr.data_payload
module_to_load = module_to_load or self
gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
if shard_strategy is not None:
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
for p in sharded_params:

@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
zero_state_dict = zero_model.state_dict()
for key, val in model.state_dict().items():
assert torch.equal(val, zero_state_dict[key])
assert torch.equal(val, zero_state_dict[key].to(val.device))
def run_dist(rank, world_size, port):

Loading…
Cancel
Save