diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index a0214f609..7d5cfdae0 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module):
         for p in sharded_params:
             p.data = p.colo_attr.data_payload
         module_to_load = module_to_load or self
-        gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
+        gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
+        gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
         if shard_strategy is not None:
             shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
         for p in sharded_params:
diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py
index 188bc5968..7ac9b151e 100644
--- a/tests/test_zero/test_state_dict.py
+++ b/tests/test_zero/test_state_dict.py
@@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
 
         zero_state_dict = zero_model.state_dict()
         for key, val in model.state_dict().items():
-            assert torch.equal(val, zero_state_dict[key])
+            assert torch.equal(val, zero_state_dict[key].to(val.device))
 
 
 def run_dist(rank, world_size, port):
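
A minimal sketch of the pattern this patch applies, using a plain nn.Module in place of ShardedModelV2: instead of deep-copying the gathered state dict on the GPU, every tensor value is copied to CPU with a dict comprehension, so callers (like the updated test) move values back to the parameter's device before an exact comparison. The helper name cpu_offload_state_dict is hypothetical and not part of the ColossalAI API.

import torch
import torch.nn as nn


# Hypothetical helper illustrating the same pattern as the patch: copy every
# tensor in a state dict to CPU so no duplicate copy lives on the GPU.
def cpu_offload_state_dict(module: nn.Module) -> dict:
    state = module.state_dict()
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}


model = nn.Linear(4, 4)
if torch.cuda.is_available():
    model = model.cuda()

cpu_state = cpu_offload_state_dict(model)
for key, val in model.state_dict().items():
    # Values are now CPU tensors; move them back to the parameter's device
    # before comparing, mirroring the updated assertion in test_state_dict.py.
    assert cpu_state[key].device.type == 'cpu'
    assert torch.equal(val, cpu_state[key].to(val.device))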