2023-04-04 05:48:16 +00:00
|
|
|
import copy
|
|
|
|
|
2022-03-10 01:57:26 +00:00
|
|
|
import torch
|
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.zero.sharded_model import ShardedModelV2
|
2022-03-10 01:57:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
    """Copy the parameters of a ShardedModelV2 into ``other_model``.

    ``other_model`` must have the same parameter structure (same order and
    count of parameters) as ``sharded_model``. For each parameter that is
    currently sharded, the full tensor is gathered first, its payload is
    deep-copied into the destination parameter, and the tensor is re-sharded
    afterwards so the source model's state is left unchanged.
    """
    strategy = sharded_model.shard_strategy
    for src_param, dst_param in zip(sharded_model.parameters(), other_model.parameters()):
        # Every parameter of a ShardedModelV2 is expected to carry colo_attr.
        assert hasattr(src_param, "colo_attr")
        data_tensor = src_param.colo_attr.sharded_data_tensor
        was_sharded = data_tensor.is_sharded
        if was_sharded:
            # Materialize the full tensor before copying.
            strategy.gather([data_tensor])
        dst_param.data = copy.deepcopy(src_param.colo_attr.data_payload)
        if was_sharded:
            # Restore the original sharded state.
            strategy.shard([data_tensor])