Merge branch 'feature_add_moe' of github.com:blankde/InternLM into feature_add_moe_pp_zl

pull/182/head
zhanglei 2023-08-15 18:49:55 +08:00
commit 7b4933de0d
5 changed files with 110 additions and 8 deletions

View File

@ -88,7 +88,7 @@ class NonPipelineScheduler(BaseScheduler):
forward_only: bool = False,
return_loss: bool = True,
scale_loss: int = 1,
moe_loss_coeff: float = 1.0,
moe_loss_coeff: float = 0.01,
):
"""Trains one batch of data.
@ -136,7 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True,
moe_loss_coeff: float = 1.0,
moe_loss_coeff: float = 0.01,
):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.

View File

@ -1,4 +1,5 @@
import typing
from typing import Dict, Tuple
import torch
@ -31,7 +32,7 @@ def has_moe_layers(m):
def is_moe_param(param: torch.Tensor) -> bool:
if hasattr(param, "allreduce") and not param.allreduce:
if hasattr(param, "all_reduce") and not param.all_reduce:
return True
return False
@ -95,7 +96,8 @@ class MoE(torch.nn.Module):
"Unsupported noisy_gate_policy: " + noisy_gate_policy
)
experts = Experts(experts, self.num_local_experts)
expert_group_name = f"ep_size_{self.ep_size}"
experts = Experts(experts, self.num_local_experts, expert_group_name)
if using_default_moe:
self.moe_layer = MOELayer(
@ -148,3 +150,94 @@ class MoE(torch.nn.Module):
coef = torch.nn.functional.softmax(coef, dim=-1)
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
def split_params_into_different_moe_groups_for_optimizer(
param_groups: Tuple[Dict], max_group_size=178956971
) -> Tuple[Dict]:
"""Split parameters into different MoE groups for optimizer
Compatiable with muiltiple param groups, each should have a name
Args:
param_groups (Tuple[Dict]):
The list of parameter groups to split
Returns:
Tuple[Dict]:
list of MoE/non-MoE groups for optimizer
"""
if isinstance(param_groups, tuple):
param_groups = list(param_groups) # Tuple cannot be modified
elif isinstance(param_groups, dict):
param_groups = [param_groups]
elif not isinstance(param_groups, list):
raise ValueError(f"Unknown param group type of {type(param_groups)}")
# gather all data parallel group names
data_parallel_group_names = set()
for param_group in param_groups:
for param in param_group["params"]:
if is_moe_param(param):
data_parallel_group_names.add(param.group_name)
data_parallel_group_names = list(data_parallel_group_names)
group_moe = {}
# Create the param MoE groups, leave param assign to next step
for param_group in param_groups:
group_moe[param_group["name"]] = {}
for key in data_parallel_group_names:
group_moe[param_group["name"]][key] = {}
group_moe[param_group["name"]][key]["name"] = key
group_moe[param_group["name"]][key]["moe"] = True
for ori_key in param_group.keys():
if ori_key != "name":
if ori_key == "params":
group_moe[param_group["name"]][key][ori_key] = []
else:
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
# Assign param
for param_group in param_groups:
new_params = []
for param in param_group["params"]:
if is_moe_param(param):
group_moe[param_group["name"]][param.group_name]["params"].append(param)
# param_group['params'].remove(param)
else:
new_params.append(param)
param_group["params"] = new_params
# Flatten the moe groups
if max_group_size is not None:
for _, v in group_moe.items():
for _, v1 in v.items():
cur_group = []
all_groups = []
size_of_cur_group = 0
for param in v1["params"]:
if size_of_cur_group + param.numel() <= max_group_size:
cur_group.append(param)
size_of_cur_group += param.numel()
else:
all_groups.append(cur_group)
cur_group = [param]
size_of_cur_group = param.numel()
if cur_group:
all_groups.append(cur_group)
for group in all_groups:
new_dict = {}
for key, val in v1.items():
if key != "params":
new_dict[key] = val
new_dict["params"] = group
param_groups.append(new_dict)
else:
for _, v in group_moe.items():
for _, v1 in v.items():
param_groups.append(v1)
return tuple(param_groups)
def create_moe_param_groups(model, weight_decay):
parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
return split_params_into_different_moe_groups_for_optimizer(parameters)

View File

@ -21,7 +21,7 @@ class Experts(torch.nn.Module):
Local Experts.
"""
def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1):
def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
super().__init__()
# TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
@ -38,6 +38,7 @@ class Experts(torch.nn.Module):
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for _, param in expert.named_parameters():
param.all_reduce = False
param.group_name = expert_group_name
def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)

View File

@ -166,6 +166,10 @@ class HybridZeroOptimizer(BaseOptimizer):
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
if "moe" in param_group.keys() and param_group["moe"]:
print("true", flush=True)
continue
group_params = param_group["params"]
# add the fp16 params to fp16_param_groups for bookkeeping
@ -512,7 +516,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# all_groups_norm_old = all_groups_norm
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
pg = gpc.get_group(ParallelMode.DATA)
print(type(norm_groups))
scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
scaled_norm_tensor = torch.tensor(
scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float

View File

@ -30,7 +30,7 @@ from internlm.data.packed_dataset import (
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.model.moe import has_moe_layers
from internlm.model.moe import create_moe_param_groups, has_moe_layers
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
@ -300,9 +300,14 @@ def initialize_optimizer(model: nn.Module):
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
adam_cfg = gpc.config.adam
if gpc.config.model.num_experts > 1:
params = create_moe_param_groups(model, adam_cfg.weight_decay)
else:
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
params=params,
lr=adam_cfg.lr,
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
eps=adam_cfg.adam_eps,