refactor code

pull/182/head
Qu Wenwen 2023-09-19 12:30:40 +08:00
parent b2f3611b47
commit 4a47872382
7 changed files with 103 additions and 151 deletions

View File

@@ -149,11 +149,16 @@ class ParallelContext(metaclass=SingletonMeta):
         self.num_processes_on_current_node = -1
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
+        self._expert_parallel_group_names = []

     @property
     def config(self):
         return self._config

+    @property
+    def expert_parallel_group_names(self):
+        return self._expert_parallel_group_names
+
     def load_config(self, config: Union[dict, str]):
         """Loads the configuration from either a dict or a file.

View File

@@ -1,5 +1,4 @@
 import typing
-from typing import Dict, Tuple

 import torch
@@ -20,36 +19,6 @@ from internlm.utils.logger import get_logger

 logger = get_logger(__file__)

-
-def has_moe_layers(m):
-    has_moe = False
-    num_experts = 0
-    for _, module in m.named_modules():
-        if isinstance(module, MoE):
-            has_moe = True
-            num_experts = module.num_experts
-            break
-    return has_moe, num_experts
-
-
-def is_moe_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_expert") and param.is_expert:
-        return True
-    return False
-
-
-def is_gate_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_gate") and param.is_gate:
-        return True
-    return False
-
-
-def is_norm_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_norm") and param.is_norm:
-        return True
-    return False
-

 class MoE(torch.nn.Module):
     """Initialize an MoE layer.
@@ -110,7 +79,9 @@ class MoE(torch.nn.Module):
         )

         # for elastic expert paralle, experts may have multiple groups
-        expert_group_name = f"ep_size_{self.ep_size}"
+        expert_group_name = f"moe_ep_size_{self.ep_size}"
+        if expert_group_name not in gpc.expert_parallel_group_names:
+            gpc.expert_parallel_group_names.append(expert_group_name)
         experts = torch.nn.ModuleList(
             [
                 # TODO have trouble when use internlm.model.linear.FeedForward
@@ -188,118 +159,3 @@ class MoE(torch.nn.Module):
             coef = torch.nn.functional.softmax(coef, dim=-1)
             output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
-
-
-def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
-    """Split parameters into different MoE groups for optimizer
-
-    Compatiable with muiltiple param groups, each should have a name
-
-    Args:
-        param_groups (Tuple[Dict]):
-            The list of parameter groups to split
-
-    Returns:
-        Tuple[Dict]:
-            list of MoE/non-MoE groups for optimizer
-    """
-    if isinstance(param_groups, tuple):
-        param_groups = list(param_groups)  # Tuple cannot be modified
-    elif isinstance(param_groups, dict):
-        param_groups = [param_groups]
-    elif not isinstance(param_groups, list):
-        raise ValueError(f"Unknown param group type of {type(param_groups)}")
-
-    # gather all data parallel group names
-    data_parallel_group_names = set()
-    for param_group in param_groups:
-        for param in param_group["params"]:
-            if is_moe_param(param):
-                data_parallel_group_names.add(param.group_name)
-    data_parallel_group_names = list(data_parallel_group_names)
-    group_moe = {}
-    gate_group = {}
-    norm_group = {}
-    # Create the param MoE groups, leave param assign to next step
-    for param_group in param_groups:
-        group_moe[param_group["name"]] = {}
-        for key in data_parallel_group_names:
-            group_moe[param_group["name"]][key] = {}
-            group_moe[param_group["name"]][key]["name"] = key
-            group_moe[param_group["name"]][key]["moe"] = True
-            for ori_key in param_group.keys():
-                if ori_key != "name":
-                    if ori_key == "params":
-                        group_moe[param_group["name"]][key][ori_key] = []
-                    else:
-                        group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
-        gate_group["name"] = "gate"
-        gate_group["gate"] = True
-        for ori_key in param_group.keys():
-            if ori_key != "name":
-                if ori_key == "params":
-                    gate_group[ori_key] = []
-                else:
-                    gate_group[ori_key] = param_group[ori_key]
-        norm_group["name"] = "norm"
-        norm_group["norm"] = True
-        for ori_key in param_group.keys():
-            if ori_key != "name":
-                if ori_key == "params":
-                    norm_group[ori_key] = []
-                else:
-                    norm_group[ori_key] = param_group[ori_key]
-
-    # Assign param
-    norm_params = []
-    gate_params = []
-    for param_group in param_groups:
-        new_params = []
-        for param in param_group["params"]:
-            if is_moe_param(param):
-                group_moe[param_group["name"]][param.group_name]["params"].append(param)
-            elif is_norm_param(param):
-                norm_params.append(param)
-            elif is_gate_param(param):
-                gate_params.append(param)
-            else:
-                new_params.append(param)
-        param_group["params"] = new_params
-    norm_group["params"] = norm_params
-    gate_group["params"] = gate_params
-    param_groups.append(norm_group)
-    param_groups.append(gate_group)
-
-    # Flatten the moe groups
-    if max_group_size is not None:
-        for _, v in group_moe.items():
-            for _, v1 in v.items():
-                cur_group = []
-                all_groups = []
-                size_of_cur_group = 0
-                for param in v1["params"]:
-                    if size_of_cur_group + param.numel() <= max_group_size:
-                        cur_group.append(param)
-                        size_of_cur_group += param.numel()
-                    else:
-                        all_groups.append(cur_group)
-                        cur_group = [param]
-                        size_of_cur_group = param.numel()
-                if cur_group:
-                    all_groups.append(cur_group)
-                for group in all_groups:
-                    new_dict = {}
-                    for key, val in v1.items():
-                        if key != "params":
-                            new_dict[key] = val
-                    new_dict["params"] = group
-                    param_groups.append(new_dict)
-    else:
-        for _, v in group_moe.items():
-            for _, v1 in v.items():
-                param_groups.append(v1)
-    return tuple(param_groups)
-
-
-def create_moe_param_groups(model, weight_decay):
-    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
-    return split_params_into_different_moe_groups_for_optimizer(parameters)

View File

@@ -207,3 +207,21 @@ def try_import_RMSNorm():
        from internlm.model.norm import RMSNormTorch as RMSNorm

        return RMSNorm
+
+
+def is_moe_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_expert") and param.is_expert:
+        return True
+    return False
+
+
+def is_gate_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_gate") and param.is_gate:
+        return True
+    return False
+
+
+def is_norm_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_norm") and param.is_norm:
+        return True
+    return False
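These helpers only read flags already attached to a parameter. A hypothetical illustration of the attributes they check (how and where InternLM actually sets these flags on real model parameters is outside this diff; the group name here is just an example following the moe_ep_size_{ep_size} pattern above):

import torch

from internlm.model.utils import is_gate_param, is_moe_param

w = torch.nn.Parameter(torch.empty(8, 8))
w.is_expert = True               # flag checked by is_moe_param()
w.group_name = "moe_ep_size_4"   # expert-parallel group this weight belongs to (illustrative)
assert is_moe_param(w)
assert not is_gate_param(w)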

View File

@@ -11,7 +11,7 @@ from torch.optim import Optimizer

 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.moe import is_moe_param
+from internlm.model.utils import is_moe_param
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,

View File

@@ -25,13 +25,13 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.moe import create_moe_param_groups
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
+from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -112,7 +112,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     adam_cfg = gpc.config.adam
     # split the moe parameters into different groups
     if gpc.config.model.num_experts > 1:
-        params = create_moe_param_groups(model, adam_cfg.weight_decay)
+        params = create_param_groups(model, adam_cfg.weight_decay)
     else:
         params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
     naive_optimizer = torch.optim.AdamW(

internlm/train/utils.py (new file, 73 lines added)
View File

@@ -0,0 +1,73 @@
+from typing import Dict, Tuple
+
+from internlm.core.context.parallel_context import global_context as gpc
+from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into different MoE groups for optimizer
+
+    Compatiable with muiltiple param groups, each should have a name
+
+    Args:
+        param_groups (Tuple[Dict]):
+            The list of parameter groups to split
+
+    Returns:
+        Tuple[Dict]:
+            list of MoE/non-MoE groups for optimizer
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # Tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    new_groups = {}
+    for pgroup in param_groups:
+        new_groups[pgroup["name"]] = {}
+        # create new groups for gate and norm
+        for key in ["gate", "norm"]:
+            new_groups[pgroup["name"]][key] = {}
+            new_groups[pgroup["name"]][key]["name"] = key
+            new_groups[pgroup["name"]][key][key] = True
+        # create moe groups
+        for key in gpc.expert_parallel_group_names:
+            new_groups[pgroup["name"]][key] = {}
+            new_groups[pgroup["name"]][key]["name"] = key
+            new_groups[pgroup["name"]][key]["moe"] = True
+        # copy attribute from origin group
+        for ori_key in pgroup.keys():
+            for key in new_groups[pgroup["name"]].keys():
+                if ori_key != "name":
+                    if ori_key == "params":
+                        new_groups[pgroup["name"]][key][ori_key] = []
+                    else:
+                        new_groups[pgroup["name"]][key][ori_key] = pgroup[ori_key]
+        # Assign param
+        origin_params = []
+        for param in pgroup["params"]:
+            if is_moe_param(param):
+                new_groups[pgroup["name"]][param.group_name]["params"].append(param)
+            elif is_norm_param(param):
+                new_groups[pgroup["name"]]["norm"]["params"].append(param)
+            elif is_gate_param(param):
+                new_groups[pgroup["name"]]["gate"]["params"].append(param)
+            else:
+                origin_params.append(param)
+        pgroup["params"] = origin_params
+
+    for _, v in new_groups.items():
+        for _, v1 in v.items():
+            param_groups.append(v1)
+
+    return tuple(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+    return split_params_into_different_groups_for_optimizer(parameters)
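For reference, a rough sketch of the group layout this helper produces and how it is consumed, mirroring the initialize_optimizer change above. The concrete group names, tensor sizes, and hyperparameters below are illustrative only; extra keys such as "name" and "moe" are simply carried along inside torch.optim param groups.

import torch

default_w = torch.nn.Parameter(torch.randn(4, 4))
expert_w = torch.nn.Parameter(torch.randn(4, 4))
# The returned tuple holds the original "default" group plus one group per
# expert-parallel name (flagged "moe") and the extra "gate"/"norm" groups.
param_groups = (
    {"name": "default", "params": [default_w], "weight_decay": 0.01},
    {"name": "moe_ep_size_4", "moe": True, "params": [expert_w], "weight_decay": 0.01},
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, betas=(0.9, 0.95))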

View File

@@ -5,7 +5,7 @@ import torch.distributed as dist

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.moe import is_moe_param
+from internlm.model.utils import is_moe_param


 def is_model_parallel_parameter(p):