mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] shared model returns cpu state_dict (#1328)
parent
b2475d8c5c
commit
7a05367101
|
@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module):
|
||||||
for p in sharded_params:
|
for p in sharded_params:
|
||||||
p.data = p.colo_attr.data_payload
|
p.data = p.colo_attr.data_payload
|
||||||
module_to_load = module_to_load or self
|
module_to_load = module_to_load or self
|
||||||
gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
|
gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
|
||||||
|
gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
|
||||||
if shard_strategy is not None:
|
if shard_strategy is not None:
|
||||||
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
|
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
|
||||||
for p in sharded_params:
|
for p in sharded_params:
|
||||||
|
|
|
@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
|
||||||
|
|
||||||
zero_state_dict = zero_model.state_dict()
|
zero_state_dict = zero_model.state_dict()
|
||||||
for key, val in model.state_dict().items():
|
for key, val in model.state_dict().items():
|
||||||
assert torch.equal(val, zero_state_dict[key])
|
assert torch.equal(val, zero_state_dict[key].to(val.device))
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
|
|
Loading…
Reference in New Issue