mirror of https://github.com/InternLM/InternLM
Merge pull request #3 from blankde/fix/gate_inconsistent_issue
fix(gate): gate inconsistent issuepull/182/head
commit
b2f3611b47
|
@ -123,6 +123,11 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
|
||||
for param in self.norm1.parameters():
|
||||
param.is_norm = True
|
||||
for param in self.norm2.parameters():
|
||||
param.is_norm = True
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.moe_gate_k = moe_gate_k
|
||||
self.moe_capacity_factor = moe_capacity_factor
|
||||
|
|
|
@ -38,6 +38,18 @@ def is_moe_param(param: torch.Tensor) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def is_gate_param(param: torch.Tensor) -> bool:
|
||||
if hasattr(param, "is_gate") and param.is_gate:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_norm_param(param: torch.Tensor) -> bool:
|
||||
if hasattr(param, "is_norm") and param.is_norm:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class MoE(torch.nn.Module):
|
||||
"""Initialize an MoE layer.
|
||||
|
||||
|
@ -205,55 +217,85 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
|
|||
data_parallel_group_names.add(param.group_name)
|
||||
data_parallel_group_names = list(data_parallel_group_names)
|
||||
group_moe = {}
|
||||
gate_group = {}
|
||||
norm_group = {}
|
||||
# Create the param MoE groups, leave param assign to next step
|
||||
for param_group in param_groups:
|
||||
group_moe[param_group["name"]] = {}
|
||||
for key in data_parallel_group_names:
|
||||
group_moe[key] = {}
|
||||
group_moe[key]["name"] = key
|
||||
group_moe[key]["moe"] = True
|
||||
group_moe[param_group["name"]][key] = {}
|
||||
group_moe[param_group["name"]][key]["name"] = key
|
||||
group_moe[param_group["name"]][key]["moe"] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != "name":
|
||||
if ori_key == "params":
|
||||
group_moe[key][ori_key] = []
|
||||
group_moe[param_group["name"]][key][ori_key] = []
|
||||
else:
|
||||
group_moe[key][ori_key] = param_group[ori_key]
|
||||
|
||||
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
|
||||
gate_group["name"] = "gate"
|
||||
gate_group["gate"] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != "name":
|
||||
if ori_key == "params":
|
||||
gate_group[ori_key] = []
|
||||
else:
|
||||
gate_group[ori_key] = param_group[ori_key]
|
||||
norm_group["name"] = "norm"
|
||||
norm_group["norm"] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != "name":
|
||||
if ori_key == "params":
|
||||
norm_group[ori_key] = []
|
||||
else:
|
||||
norm_group[ori_key] = param_group[ori_key]
|
||||
# Assign param
|
||||
norm_params = []
|
||||
gate_params = []
|
||||
for param_group in param_groups:
|
||||
new_params = []
|
||||
for param in param_group["params"]:
|
||||
if is_moe_param(param):
|
||||
group_moe[param.group_name]["params"].append(param)
|
||||
# param_group['params'].remove(param)
|
||||
group_moe[param_group["name"]][param.group_name]["params"].append(param)
|
||||
elif is_norm_param(param):
|
||||
norm_params.append(param)
|
||||
elif is_gate_param(param):
|
||||
gate_params.append(param)
|
||||
else:
|
||||
new_params.append(param)
|
||||
param_group["params"] = new_params
|
||||
norm_group["params"] = norm_params
|
||||
gate_group["params"] = gate_params
|
||||
param_groups.append(norm_group)
|
||||
param_groups.append(gate_group)
|
||||
|
||||
# Flatten the moe groups
|
||||
if max_group_size is not None:
|
||||
for _, v1 in group_moe.items():
|
||||
cur_group = []
|
||||
all_groups = []
|
||||
size_of_cur_group = 0
|
||||
for param in v1["params"]:
|
||||
cur_group.append(param)
|
||||
size_of_cur_group += param.numel()
|
||||
if size_of_cur_group > max_group_size:
|
||||
for _, v in group_moe.items():
|
||||
for _, v1 in v.items():
|
||||
cur_group = []
|
||||
all_groups = []
|
||||
size_of_cur_group = 0
|
||||
for param in v1["params"]:
|
||||
if size_of_cur_group + param.numel() <= max_group_size:
|
||||
cur_group.append(param)
|
||||
size_of_cur_group += param.numel()
|
||||
else:
|
||||
all_groups.append(cur_group)
|
||||
cur_group = [param]
|
||||
size_of_cur_group = param.numel()
|
||||
if cur_group:
|
||||
all_groups.append(cur_group)
|
||||
cur_group = []
|
||||
size_of_cur_group = 0
|
||||
if cur_group:
|
||||
all_groups.append(cur_group)
|
||||
for group in all_groups:
|
||||
new_dict = {}
|
||||
for key, val in v1.items():
|
||||
if key != "params":
|
||||
new_dict[key] = val
|
||||
new_dict["params"] = group
|
||||
param_groups.append(new_dict)
|
||||
for group in all_groups:
|
||||
new_dict = {}
|
||||
for key, val in v1.items():
|
||||
if key != "params":
|
||||
new_dict[key] = val
|
||||
new_dict["params"] = group
|
||||
param_groups.append(new_dict)
|
||||
else:
|
||||
for _, v1 in group_moe.items():
|
||||
param_groups.append(v1)
|
||||
for _, v in group_moe.items():
|
||||
for _, v1 in v.items():
|
||||
param_groups.append(v1)
|
||||
return tuple(param_groups)
|
||||
|
||||
|
||||
|
|
|
@ -167,11 +167,6 @@ def _top_idx(source, k):
|
|||
return torch.topk(source, k=k, dim=0)[1]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _one_hot_to_float(x, num_classes):
|
||||
return F.one_hot(x, num_classes=num_classes).float()
|
||||
|
||||
|
||||
def top1gating(
|
||||
logits: Tensor,
|
||||
capacity_factor: float,
|
||||
|
@ -210,7 +205,7 @@ def top1gating(
|
|||
|
||||
# Compute l_aux
|
||||
me = torch.mean(gates, dim=0)
|
||||
ce = torch.mean(mask1.float(), dim=0)
|
||||
ce = torch.mean(mask1.type_as(logits), dim=0)
|
||||
l_aux = torch.sum(me * ce) * num_experts
|
||||
|
||||
# Random Token Selection
|
||||
|
@ -244,10 +239,10 @@ def top1gating(
|
|||
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||
|
||||
# Normalize gate probabilities
|
||||
mask1_float = mask1.float()
|
||||
mask1_float = mask1.type_as(logits)
|
||||
gates = gates * mask1_float
|
||||
|
||||
locations1_sc = _one_hot_to_float(locations1_s, capacity)
|
||||
locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
|
||||
combine_weights = einsum("se,sc->sec", gates, locations1_sc)
|
||||
|
||||
dispatch_mask = combine_weights.bool()
|
||||
|
@ -271,7 +266,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
|
|||
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
|
||||
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
|
||||
# Replace top-expert with min value
|
||||
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
|
||||
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
|
||||
indices2_s = torch.argmax(logits_except1, dim=1)
|
||||
mask2 = F.one_hot(indices2_s, num_classes=num_experts)
|
||||
|
||||
|
@ -286,7 +281,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
|
|||
|
||||
# Compute l_aux
|
||||
me = torch.mean(gates, dim=0)
|
||||
ce = torch.mean(mask1.float(), dim=0)
|
||||
ce = torch.mean(mask1.type_as(logits), dim=0)
|
||||
l_aux = torch.mean(me * ce) * num_experts * num_experts
|
||||
|
||||
# Remove locations outside capacity from mask
|
||||
|
@ -298,8 +293,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
|
|||
locations2_s = torch.sum(locations2 * mask2, dim=1)
|
||||
|
||||
# Normalize gate probabilities
|
||||
mask1_float = mask1.float()
|
||||
mask2_float = mask2.float()
|
||||
mask1_float = mask1.type_as(logits)
|
||||
mask2_float = mask2.type_as(logits)
|
||||
gates1_s = einsum("se,se->s", gates, mask1_float)
|
||||
gates2_s = einsum("se,se->s", gates, mask2_float)
|
||||
denom_s = gates1_s + gates2_s
|
||||
|
@ -311,8 +306,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
|
|||
# Calculate combine_weights and dispatch_mask
|
||||
gates1 = einsum("s,se->se", gates1_s, mask1_float)
|
||||
gates2 = einsum("s,se->se", gates2_s, mask2_float)
|
||||
locations1_sc = _one_hot_to_float(locations1_s, capacity)
|
||||
locations2_sc = _one_hot_to_float(locations2_s, capacity)
|
||||
locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
|
||||
locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
|
||||
combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
|
||||
combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
|
||||
combine_weights = combine1_sec + combine2_sec
|
||||
|
@ -357,7 +352,7 @@ class TopKGate(Module):
|
|||
if k not in (1, 2):
|
||||
raise ValueError("Only top-1 and top-2 gatings are supported.")
|
||||
# Deepspeed's mechisms, alway use fp32
|
||||
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
|
||||
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
|
||||
self.k = k
|
||||
self.capacity_factor = capacity_factor
|
||||
self.eval_capacity_factor = eval_capacity_factor
|
||||
|
@ -368,6 +363,9 @@ class TopKGate(Module):
|
|||
self.drop_tokens = drop_tokens
|
||||
self.use_rts = use_rts
|
||||
|
||||
for param in self.wg.parameters():
|
||||
param.is_gate = True
|
||||
|
||||
def forward(
|
||||
self, inputs: torch.Tensor, used_token: torch.Tensor = None
|
||||
) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
|
||||
|
@ -375,13 +373,10 @@ class TopKGate(Module):
|
|||
if self.wall_clock_breakdown:
|
||||
timer("TopKGate").start()
|
||||
|
||||
if self.wg.weight.dtype != torch.float32: # TODO can we change it to fp16
|
||||
self.wg = self.wg.float()
|
||||
inputs_fp32 = inputs.float()
|
||||
# input jittering
|
||||
if self.noisy_gate_policy == "Jitter" and self.training:
|
||||
inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
|
||||
logits = self.wg(inputs_fp32)
|
||||
inputs = multiplicative_jitter(inputs, device=inputs.device)
|
||||
logits = self.wg(inputs)
|
||||
|
||||
if self.k == 1:
|
||||
gate_output = top1gating(
|
||||
|
|
|
@ -308,6 +308,12 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
def _is_moe_group(self, param_group):
|
||||
return "moe" in param_group.keys() and param_group["moe"]
|
||||
|
||||
def _is_norm_group(self, param_group):
|
||||
return "norm" in param_group.keys() and param_group["norm"]
|
||||
|
||||
def _is_gate_group(self, param_group):
|
||||
return "gate" in param_group.keys() and param_group["gate"]
|
||||
|
||||
def _attach_reduction_hook(self):
|
||||
# we iterate over the fp16 params
|
||||
# on each param, we register a hook to its AccumulateGrad object
|
||||
|
@ -688,6 +694,22 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
param_shape == flat_fp32_avg_grads.shape
|
||||
), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
|
||||
|
||||
# Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
|
||||
# Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
|
||||
if self._is_norm_group(self.optim.param_groups[group_id]):
|
||||
dist.all_reduce(
|
||||
flat_fp32_avg_grads,
|
||||
op=dist.ReduceOp.AVG,
|
||||
group=gpc.get_group(ParallelMode.TENSOR),
|
||||
)
|
||||
|
||||
if self._is_gate_group(self.optim.param_groups[group_id]):
|
||||
dist.all_reduce(
|
||||
flat_fp32_avg_grads,
|
||||
op=dist.ReduceOp.AVG,
|
||||
group=gpc.get_group(ParallelMode.TENSOR),
|
||||
)
|
||||
|
||||
single_grad_partition_groups.append(flat_fp32_avg_grads)
|
||||
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
|
||||
|
|
Loading…
Reference in New Issue