Browse Source

[hotfix] shared model returns cpu state_dict (#1328)

pull/1329/head
ver217 2 years ago committed by GitHub
parent
commit
7a05367101
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      colossalai/zero/sharded_model/sharded_model_v2.py
  2. 2
      tests/test_zero/test_state_dict.py

3
colossalai/zero/sharded_model/sharded_model_v2.py

@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module):
for p in sharded_params: for p in sharded_params:
p.data = p.colo_attr.data_payload p.data = p.colo_attr.data_payload
module_to_load = module_to_load or self module_to_load = module_to_load or self
gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars)) gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
if shard_strategy is not None: if shard_strategy is not None:
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
for p in sharded_params: for p in sharded_params:

2
tests/test_zero/test_state_dict.py

@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
zero_state_dict = zero_model.state_dict() zero_state_dict = zero_model.state_dict()
for key, val in model.state_dict().items(): for key, val in model.state_dict().items():
assert torch.equal(val, zero_state_dict[key]) assert torch.equal(val, zero_state_dict[key].to(val.device))
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):

Loading…
Cancel
Save