|
|
@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module): |
|
|
|
for p in sharded_params: |
|
|
|
for p in sharded_params: |
|
|
|
p.data = p.colo_attr.data_payload |
|
|
|
p.data = p.colo_attr.data_payload |
|
|
|
module_to_load = module_to_load or self |
|
|
|
module_to_load = module_to_load or self |
|
|
|
gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars)) |
|
|
|
gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars) |
|
|
|
|
|
|
|
gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()} |
|
|
|
if shard_strategy is not None: |
|
|
|
if shard_strategy is not None: |
|
|
|
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) |
|
|
|
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) |
|
|
|
for p in sharded_params: |
|
|
|
for p in sharded_params: |
|
|
|