mirror of https://github.com/InternLM/InternLM
Merge pull request #3 from blankde/fix/gate_inconsistent_issue
fix(gate): gate inconsistent issue (pull/182/head)
commit b2f3611b47
@@ -123,6 +123,11 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

+        for param in self.norm1.parameters():
+            param.is_norm = True
+        for param in self.norm2.parameters():
+            param.is_norm = True
+
         self.num_experts = num_experts
         self.moe_gate_k = moe_gate_k
         self.moe_capacity_factor = moe_capacity_factor
@@ -38,6 +38,18 @@ def is_moe_param(param: torch.Tensor) -> bool:
     return False


+def is_gate_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_gate") and param.is_gate:
+        return True
+    return False
+
+
+def is_norm_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_norm") and param.is_norm:
+        return True
+    return False
+
+
 class MoE(torch.nn.Module):
     """Initialize an MoE layer.

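
The two helpers above mirror the existing is_moe_param: they key off boolean attributes that the rest of the diff stamps onto parameters at construction time (norm1/norm2 in PackedFlashBaseLayer1D, the wg weight in TopKGate). A minimal standalone sketch of that tagging-and-routing pattern, using toy modules rather than the repo's classes:

import torch

norm = torch.nn.LayerNorm(8)
gate_proj = torch.nn.Linear(8, 4, bias=False)
for p in norm.parameters():
    p.is_norm = True           # same flag the diff attaches to norm1/norm2
for p in gate_proj.parameters():
    p.is_gate = True           # same flag TopKGate attaches to its wg weight

def is_norm_param(param: torch.Tensor) -> bool:
    return hasattr(param, "is_norm") and param.is_norm

def is_gate_param(param: torch.Tensor) -> bool:
    return hasattr(param, "is_gate") and param.is_gate

params = list(norm.parameters()) + list(gate_proj.parameters())
print([is_norm_param(p) for p in params])   # [True, True, False]
print([is_gate_param(p) for p in params])   # [False, False, True]
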
@@ -205,55 +217,85 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
         data_parallel_group_names.add(param.group_name)
     data_parallel_group_names = list(data_parallel_group_names)
     group_moe = {}
+    gate_group = {}
+    norm_group = {}
     # Create the param MoE groups, leave param assign to next step
     for param_group in param_groups:
+        group_moe[param_group["name"]] = {}
         for key in data_parallel_group_names:
-            group_moe[key] = {}
-            group_moe[key]["name"] = key
-            group_moe[key]["moe"] = True
+            group_moe[param_group["name"]][key] = {}
+            group_moe[param_group["name"]][key]["name"] = key
+            group_moe[param_group["name"]][key]["moe"] = True
             for ori_key in param_group.keys():
                 if ori_key != "name":
                     if ori_key == "params":
-                        group_moe[key][ori_key] = []
+                        group_moe[param_group["name"]][key][ori_key] = []
                     else:
-                        group_moe[key][ori_key] = param_group[ori_key]
+                        group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
+        gate_group["name"] = "gate"
+        gate_group["gate"] = True
+        for ori_key in param_group.keys():
+            if ori_key != "name":
+                if ori_key == "params":
+                    gate_group[ori_key] = []
+                else:
+                    gate_group[ori_key] = param_group[ori_key]
+        norm_group["name"] = "norm"
+        norm_group["norm"] = True
+        for ori_key in param_group.keys():
+            if ori_key != "name":
+                if ori_key == "params":
+                    norm_group[ori_key] = []
+                else:
+                    norm_group[ori_key] = param_group[ori_key]
     # Assign param
+    norm_params = []
+    gate_params = []
     for param_group in param_groups:
         new_params = []
         for param in param_group["params"]:
             if is_moe_param(param):
-                group_moe[param.group_name]["params"].append(param)
-                # param_group['params'].remove(param)
+                group_moe[param_group["name"]][param.group_name]["params"].append(param)
+            elif is_norm_param(param):
+                norm_params.append(param)
+            elif is_gate_param(param):
+                gate_params.append(param)
             else:
                 new_params.append(param)
         param_group["params"] = new_params
+    norm_group["params"] = norm_params
+    gate_group["params"] = gate_params
+    param_groups.append(norm_group)
+    param_groups.append(gate_group)

     # Flatten the moe groups
     if max_group_size is not None:
-        for _, v1 in group_moe.items():
-            cur_group = []
-            all_groups = []
-            size_of_cur_group = 0
-            for param in v1["params"]:
-                cur_group.append(param)
-                size_of_cur_group += param.numel()
-                if size_of_cur_group > max_group_size:
-                    all_groups.append(cur_group)
-                    cur_group = []
-                    size_of_cur_group = 0
-            if cur_group:
-                all_groups.append(cur_group)
-            for group in all_groups:
-                new_dict = {}
-                for key, val in v1.items():
-                    if key != "params":
-                        new_dict[key] = val
-                new_dict["params"] = group
-                param_groups.append(new_dict)
+        for _, v in group_moe.items():
+            for _, v1 in v.items():
+                cur_group = []
+                all_groups = []
+                size_of_cur_group = 0
+                for param in v1["params"]:
+                    if size_of_cur_group + param.numel() <= max_group_size:
+                        cur_group.append(param)
+                        size_of_cur_group += param.numel()
+                    else:
+                        all_groups.append(cur_group)
+                        cur_group = [param]
+                        size_of_cur_group = param.numel()
+                if cur_group:
+                    all_groups.append(cur_group)
+                for group in all_groups:
+                    new_dict = {}
+                    for key, val in v1.items():
+                        if key != "params":
+                            new_dict[key] = val
+                    new_dict["params"] = group
+                    param_groups.append(new_dict)
     else:
-        for _, v1 in group_moe.items():
-            param_groups.append(v1)
+        for _, v in group_moe.items():
+            for _, v1 in v.items():
+                param_groups.append(v1)
     return tuple(param_groups)

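
With the tags in place, split_params_into_different_moe_groups_for_optimizer now also peels the norm and gate parameters out of the ordinary groups and appends dedicated "norm" and "gate" groups that carry the original hyperparameters, which HybridZeroOptimizer later recognises by their flags. A hedged, heavily simplified sketch of that regrouping step (toy modules and illustrative group contents, not the repo's function):

import torch

norm = torch.nn.LayerNorm(8)
gate_proj = torch.nn.Linear(8, 4, bias=False)
ffn = torch.nn.Linear(8, 8)
for p in norm.parameters():
    p.is_norm = True
for p in gate_proj.parameters():
    p.is_gate = True

base = {"name": "default", "lr": 1e-4,
        "params": list(norm.parameters()) + list(gate_proj.parameters()) + list(ffn.parameters())}

# Carry every non-"name"/"params" key (lr, weight decay, ...) over to the new groups, as the diff does.
norm_group = {k: v for k, v in base.items() if k not in ("name", "params")}
gate_group = dict(norm_group)
norm_group.update(name="norm", norm=True, params=[])
gate_group.update(name="gate", gate=True, params=[])

kept = []
for p in base["params"]:
    if getattr(p, "is_norm", False):
        norm_group["params"].append(p)
    elif getattr(p, "is_gate", False):
        gate_group["params"].append(p)
    else:
        kept.append(p)
base["params"] = kept

param_groups = (base, norm_group, gate_group)
print([(g["name"], len(g["params"])) for g in param_groups])
# [('default', 2), ('norm', 2), ('gate', 1)]
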
@@ -167,11 +167,6 @@ def _top_idx(source, k):
     return torch.topk(source, k=k, dim=0)[1]


-@torch.jit.script
-def _one_hot_to_float(x, num_classes):
-    return F.one_hot(x, num_classes=num_classes).float()
-
-
 def top1gating(
     logits: Tensor,
     capacity_factor: float,
@@ -210,7 +205,7 @@ def top1gating(

     # Compute l_aux
     me = torch.mean(gates, dim=0)
-    ce = torch.mean(mask1.float(), dim=0)
+    ce = torch.mean(mask1.type_as(logits), dim=0)
     l_aux = torch.sum(me * ce) * num_experts

     # Random Token Selection
@@ -244,10 +239,10 @@ def top1gating(
     locations1_s = torch.sum(locations1 * mask1, dim=1)

     # Normalize gate probabilities
-    mask1_float = mask1.float()
+    mask1_float = mask1.type_as(logits)
     gates = gates * mask1_float

-    locations1_sc = _one_hot_to_float(locations1_s, capacity)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
     combine_weights = einsum("se,sc->sec", gates, locations1_sc)

     dispatch_mask = combine_weights.bool()
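
The gating hunks above replace hard .float() casts, and the jit-scripted _one_hot_to_float helper, with type_as(logits), so every intermediate follows the dtype of the gate logits instead of being promoted to fp32. A small sketch of the dtype behaviour this relies on (toy shapes; bfloat16 is used only because it runs on CPU):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2, dtype=torch.bfloat16)          # (tokens, experts)
indices1_s = torch.argmax(logits, dim=1)
mask1 = F.one_hot(indices1_s, num_classes=2)               # int64 one-hot

print(mask1.float().dtype)                                  # torch.float32 (old behaviour)
print(mask1.type_as(logits).dtype)                          # torch.bfloat16 (new behaviour)

capacity = 4
locations1_s = torch.zeros_like(indices1_s)                 # toy stand-in for the cumsum-based locations
locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
print(locations1_sc.dtype)                                  # torch.bfloat16
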
@@ -271,7 +266,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
     logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # Replace top-expert with min value
-    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
+    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
     indices2_s = torch.argmax(logits_except1, dim=1)
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)

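
The fill value used to knock out the top-1 expert also changes from float("-inf") to torch.finfo(logits.dtype).min, the most negative finite value the logits' own dtype can represent; that keeps the masked logits finite in reduced precision while still guaranteeing the masked expert cannot win the second argmax. A quick illustration (printed values shown approximately):

import torch
import torch.nn.functional as F

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, torch.finfo(dtype).min)
# torch.float16  -65504.0
# torch.bfloat16 -3.39e+38 (approx.)
# torch.float32  -3.40e+38 (approx.)

logits = torch.randn(4, 3, dtype=torch.bfloat16)
mask1 = F.one_hot(torch.argmax(logits, dim=1), num_classes=3)
logits_except1 = logits.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
indices2_s = torch.argmax(logits_except1, dim=1)            # second-choice expert per token
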
@@ -286,7 +281,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup

     # Compute l_aux
     me = torch.mean(gates, dim=0)
-    ce = torch.mean(mask1.float(), dim=0)
+    ce = torch.mean(mask1.type_as(logits), dim=0)
     l_aux = torch.mean(me * ce) * num_experts * num_experts

     # Remove locations outside capacity from mask
@@ -298,8 +293,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     locations2_s = torch.sum(locations2 * mask2, dim=1)

     # Normalize gate probabilities
-    mask1_float = mask1.float()
-    mask2_float = mask2.float()
+    mask1_float = mask1.type_as(logits)
+    mask2_float = mask2.type_as(logits)
     gates1_s = einsum("se,se->s", gates, mask1_float)
     gates2_s = einsum("se,se->s", gates, mask2_float)
     denom_s = gates1_s + gates2_s
@@ -311,8 +306,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = _one_hot_to_float(locations1_s, capacity)
-    locations2_sc = _one_hot_to_float(locations2_s, capacity)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
+    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
     combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
     combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
     combine_weights = combine1_sec + combine2_sec
@@ -357,7 +352,7 @@ class TopKGate(Module):
         if k not in (1, 2):
            raise ValueError("Only top-1 and top-2 gatings are supported.")
         # Deepspeed's mechisms, alway use fp32
-        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.k = k
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
@@ -368,6 +363,9 @@ class TopKGate(Module):
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts

+        for param in self.wg.parameters():
+            param.is_gate = True
+
     def forward(
         self, inputs: torch.Tensor, used_token: torch.Tensor = None
     ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
@@ -375,13 +373,10 @@ class TopKGate(Module):
         if self.wall_clock_breakdown:
             timer("TopKGate").start()

-        if self.wg.weight.dtype != torch.float32:  # TODO can we change it to fp16
-            self.wg = self.wg.float()
-        inputs_fp32 = inputs.float()
         # input jittering
         if self.noisy_gate_policy == "Jitter" and self.training:
-            inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
-        logits = self.wg(inputs_fp32)
+            inputs = multiplicative_jitter(inputs, device=inputs.device)
+        logits = self.wg(inputs)

         if self.k == 1:
             gate_output = top1gating(
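
With the fp32 up-cast removed from TopKGate.forward, the gate projection simply runs in whatever dtype the activations arrive in, and its weight is tagged as a gate parameter so the optimizer can find it later. A stripped-down, hypothetical sketch of such a gate (TinyGate is illustrative and not the repo's class):

import torch

class TinyGate(torch.nn.Module):
    def __init__(self, model_dim: int, num_experts: int):
        super().__init__()
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        for p in self.wg.parameters():
            p.is_gate = True                  # mirrors the tagging added in the diff

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.wg(inputs)                # no .float(): logits share the inputs' dtype

gate = TinyGate(8, 4).to(torch.bfloat16)
logits = gate(torch.randn(16, 8, dtype=torch.bfloat16))
print(logits.dtype)                           # torch.bfloat16
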
@@ -308,6 +308,12 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _is_moe_group(self, param_group):
         return "moe" in param_group.keys() and param_group["moe"]

+    def _is_norm_group(self, param_group):
+        return "norm" in param_group.keys() and param_group["norm"]
+
+    def _is_gate_group(self, param_group):
+        return "gate" in param_group.keys() and param_group["gate"]
+
     def _attach_reduction_hook(self):
         # we iterate over the fp16 params
         # on each param, we register a hook to its AccumulateGrad object
@@ -688,6 +694,22 @@ class HybridZeroOptimizer(BaseOptimizer):
                param_shape == flat_fp32_avg_grads.shape
            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"

+            # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
+            # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
+            if self._is_norm_group(self.optim.param_groups[group_id]):
+                dist.all_reduce(
+                    flat_fp32_avg_grads,
+                    op=dist.ReduceOp.AVG,
+                    group=gpc.get_group(ParallelMode.TENSOR),
+                )
+
+            if self._is_gate_group(self.optim.param_groups[group_id]):
+                dist.all_reduce(
+                    flat_fp32_avg_grads,
+                    op=dist.ReduceOp.AVG,
+                    group=gpc.get_group(ParallelMode.TENSOR),
+                )
+
            single_grad_partition_groups.append(flat_fp32_avg_grads)
            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
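
The new branches above average the flattened fp32 gradients of the norm and gate groups across the tensor-parallel process group, so replicas that accumulated slightly different low-precision gradients step with identical values. A hedged, single-process sketch of that reduction (gloo backend and world size 1 only so it runs standalone; gloo has no ReduceOp.AVG, so SUM plus a divide stands in for the AVG used in the diff):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

tp_group = dist.new_group(ranks=[0])          # stand-in for gpc.get_group(ParallelMode.TENSOR)
flat_fp32_avg_grads = torch.randn(10)

dist.all_reduce(flat_fp32_avg_grads, op=dist.ReduceOp.SUM, group=tp_group)
flat_fp32_avg_grads /= dist.get_world_size(tp_group)      # SUM + divide emulates AVG

dist.destroy_process_group()
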