mirror of https://github.com/InternLM/InternLM
Merge branch 'feature_add_moe' of github.com:blankde/InternLM into feature_add_moe_pp_zl
commit
7b4933de0d
|
@ -88,7 +88,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
scale_loss: int = 1,
|
||||
moe_loss_coeff: float = 1.0,
|
||||
moe_loss_coeff: float = 0.01,
|
||||
):
|
||||
"""Trains one batch of data.
|
||||
|
||||
|
@ -136,7 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
moe_loss_coeff: float = 1.0,
|
||||
moe_loss_coeff: float = 0.01,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import typing
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -31,7 +32,7 @@ def has_moe_layers(m):
|
|||
|
||||
|
||||
def is_moe_param(param: torch.Tensor) -> bool:
|
||||
if hasattr(param, "allreduce") and not param.allreduce:
|
||||
if hasattr(param, "all_reduce") and not param.all_reduce:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -95,7 +96,8 @@ class MoE(torch.nn.Module):
|
|||
"Unsupported noisy_gate_policy: " + noisy_gate_policy
|
||||
)
|
||||
|
||||
experts = Experts(experts, self.num_local_experts)
|
||||
expert_group_name = f"ep_size_{self.ep_size}"
|
||||
experts = Experts(experts, self.num_local_experts, expert_group_name)
|
||||
|
||||
if using_default_moe:
|
||||
self.moe_layer = MOELayer(
|
||||
|
@ -148,3 +150,94 @@ class MoE(torch.nn.Module):
|
|||
coef = torch.nn.functional.softmax(coef, dim=-1)
|
||||
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
|
||||
return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
|
||||
|
||||
|
||||
def split_params_into_different_moe_groups_for_optimizer(
|
||||
param_groups: Tuple[Dict], max_group_size=178956971
|
||||
) -> Tuple[Dict]:
|
||||
"""Split parameters into different MoE groups for optimizer
|
||||
Compatiable with muiltiple param groups, each should have a name
|
||||
|
||||
Args:
|
||||
param_groups (Tuple[Dict]):
|
||||
The list of parameter groups to split
|
||||
|
||||
Returns:
|
||||
Tuple[Dict]:
|
||||
list of MoE/non-MoE groups for optimizer
|
||||
"""
|
||||
if isinstance(param_groups, tuple):
|
||||
param_groups = list(param_groups) # Tuple cannot be modified
|
||||
elif isinstance(param_groups, dict):
|
||||
param_groups = [param_groups]
|
||||
elif not isinstance(param_groups, list):
|
||||
raise ValueError(f"Unknown param group type of {type(param_groups)}")
|
||||
|
||||
# gather all data parallel group names
|
||||
data_parallel_group_names = set()
|
||||
for param_group in param_groups:
|
||||
for param in param_group["params"]:
|
||||
if is_moe_param(param):
|
||||
data_parallel_group_names.add(param.group_name)
|
||||
data_parallel_group_names = list(data_parallel_group_names)
|
||||
group_moe = {}
|
||||
# Create the param MoE groups, leave param assign to next step
|
||||
for param_group in param_groups:
|
||||
group_moe[param_group["name"]] = {}
|
||||
for key in data_parallel_group_names:
|
||||
group_moe[param_group["name"]][key] = {}
|
||||
group_moe[param_group["name"]][key]["name"] = key
|
||||
group_moe[param_group["name"]][key]["moe"] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != "name":
|
||||
if ori_key == "params":
|
||||
group_moe[param_group["name"]][key][ori_key] = []
|
||||
else:
|
||||
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
|
||||
# Assign param
|
||||
for param_group in param_groups:
|
||||
new_params = []
|
||||
for param in param_group["params"]:
|
||||
if is_moe_param(param):
|
||||
group_moe[param_group["name"]][param.group_name]["params"].append(param)
|
||||
# param_group['params'].remove(param)
|
||||
else:
|
||||
new_params.append(param)
|
||||
param_group["params"] = new_params
|
||||
|
||||
# Flatten the moe groups
|
||||
if max_group_size is not None:
|
||||
for _, v in group_moe.items():
|
||||
for _, v1 in v.items():
|
||||
cur_group = []
|
||||
all_groups = []
|
||||
size_of_cur_group = 0
|
||||
for param in v1["params"]:
|
||||
if size_of_cur_group + param.numel() <= max_group_size:
|
||||
cur_group.append(param)
|
||||
size_of_cur_group += param.numel()
|
||||
else:
|
||||
all_groups.append(cur_group)
|
||||
cur_group = [param]
|
||||
size_of_cur_group = param.numel()
|
||||
if cur_group:
|
||||
all_groups.append(cur_group)
|
||||
for group in all_groups:
|
||||
new_dict = {}
|
||||
for key, val in v1.items():
|
||||
if key != "params":
|
||||
new_dict[key] = val
|
||||
new_dict["params"] = group
|
||||
param_groups.append(new_dict)
|
||||
else:
|
||||
for _, v in group_moe.items():
|
||||
for _, v1 in v.items():
|
||||
param_groups.append(v1)
|
||||
|
||||
return tuple(param_groups)
|
||||
|
||||
|
||||
def create_moe_param_groups(model, weight_decay):
|
||||
parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
|
||||
|
||||
return split_params_into_different_moe_groups_for_optimizer(parameters)
|
||||
|
|
|
@ -21,7 +21,7 @@ class Experts(torch.nn.Module):
|
|||
Local Experts.
|
||||
"""
|
||||
|
||||
def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1):
|
||||
def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
|
||||
super().__init__()
|
||||
|
||||
# TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
|
||||
|
@ -38,6 +38,7 @@ class Experts(torch.nn.Module):
|
|||
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
|
||||
for _, param in expert.named_parameters():
|
||||
param.all_reduce = False
|
||||
param.group_name = expert_group_name
|
||||
|
||||
def forward(self, inputs):
|
||||
chunks = inputs.chunk(self.num_local_experts, dim=1)
|
||||
|
|
|
@ -166,6 +166,10 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# partition these param groups for data parallel training
|
||||
# and add buffers to parameter store for future access
|
||||
for group_id, param_group in enumerate(self.optim.param_groups):
|
||||
if "moe" in param_group.keys() and param_group["moe"]:
|
||||
print("true", flush=True)
|
||||
continue
|
||||
|
||||
group_params = param_group["params"]
|
||||
|
||||
# add the fp16 params to fp16_param_groups for bookkeeping
|
||||
|
@ -512,7 +516,6 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# all_groups_norm_old = all_groups_norm
|
||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
||||
pg = gpc.get_group(ParallelMode.DATA)
|
||||
print(type(norm_groups))
|
||||
scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
||||
scaled_norm_tensor = torch.tensor(
|
||||
scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
|
||||
|
|
9
train.py
9
train.py
|
@ -30,7 +30,7 @@ from internlm.data.packed_dataset import (
|
|||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.model.moe import has_moe_layers
|
||||
from internlm.model.moe import create_moe_param_groups, has_moe_layers
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
|
@ -300,9 +300,14 @@ def initialize_optimizer(model: nn.Module):
|
|||
|
||||
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
|
||||
"""
|
||||
|
||||
adam_cfg = gpc.config.adam
|
||||
if gpc.config.model.num_experts > 1:
|
||||
params = create_moe_param_groups(model, adam_cfg.weight_decay)
|
||||
else:
|
||||
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
|
||||
params=params,
|
||||
lr=adam_cfg.lr,
|
||||
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
|
||||
eps=adam_cfg.adam_eps,
|
||||
|
|
Loading…
Reference in New Issue