2023-03-31 01:20:33 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from .experts import MoeExperts
|
|
|
|
|
|
|
|
|
|
|
|
def save_moe_model(model: nn.Module, save_path: str):
    """Persist *model*'s state dict to *save_path* from rank 0 only.

    Every rank builds its state dict, but only the master process writes
    the file; a barrier then keeps all ranks in lockstep so nobody races
    ahead before the checkpoint is fully on disk.

    Args:
        model: the model whose parameters should be checkpointed.
        save_path: destination file path for ``torch.save``.
    """
    full_state = model.state_dict()
    is_master = dist.get_rank() == 0
    if is_master:
        torch.save(full_state, save_path)
    # Synchronize: no rank proceeds until the file write has completed.
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
def load_moe_model(model: nn.Module, load_path: str):
    """Load a MoE checkpoint saved by ``save_moe_model`` into *model*.

    The checkpoint stores the parameters of *all* experts under global
    expert indices (``<prefix>.experts.<global_id>.<param>``), while each
    rank only instantiates ``num_local_experts`` experts indexed locally.
    For every ``.moe_layer.experts`` module this function therefore:

    1. copies the globally-indexed parameters of the experts this rank
       owns into the locally-indexed keys the module expects, then
    2. drops every expert key outside the local index range so that
       ``load_state_dict`` does not see unexpected keys.

    Args:
        model: the expert-parallel MoE model to load parameters into.
        load_path: path of a checkpoint produced by ``save_moe_model``.
    """
    # Deserialize to CPU: without map_location every rank would materialize
    # the full checkpoint on the device the tensors were saved from
    # (typically cuda:0), which can OOM that device when many processes
    # load concurrently. load_state_dict copies into the live parameters
    # regardless of the source device, so this is safe.
    state_dict = torch.load(load_path, map_location="cpu")

    for prefix, module in model.named_modules():
        if prefix.endswith(".moe_layer.experts"):
            # this module should be an Experts instance
            assert isinstance(module, MoeExperts)

            # Rank of this process within the expert-parallel group; it
            # determines which slice of the global experts we own.
            ep_rank = dist.get_rank(module.dist_info.ep_group)
            num_local = module.num_local_experts
            for i in range(num_local):
                # Global index of the i-th expert owned by this rank.
                expert_id = ep_rank * num_local + i
                for name, _ in module.experts[i].named_parameters():
                    cur_key = f"{prefix}.experts.{i}.{name}"
                    param_key = f"{prefix}.experts.{expert_id}.{name}"
                    # Remap: local slot i receives global expert expert_id.
                    # On ep_rank 0 the two keys coincide (no-op assignment).
                    load_param = state_dict[param_key]
                    state_dict[cur_key] = load_param

            # Prune the globally-indexed keys beyond the local range; the
            # values belonging to this rank were already copied above, so
            # only keys 0..num_local-1 must survive for load_state_dict.
            for name, _ in module.experts[0].named_parameters():
                pop_pre = f"{prefix}.experts."
                pop_suf = f".{name}"
                for i in range(num_local, module.num_total_experts):
                    pop_key = f"{pop_pre}{i}{pop_suf}"
                    state_dict.pop(pop_key)

    model.load_state_dict(state_dict)
|