fix precision inconsistency

pull/182/head
zhanglei 2023-09-18 20:54:52 +08:00
parent 5aa5c96ec8
commit edc18bcddd
4 changed files with 113 additions and 49 deletions

View File

@ -123,6 +123,11 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
for param in self.norm1.parameters():
param.is_norm = True
for param in self.norm2.parameters():
param.is_norm = True
self.num_experts = num_experts
self.moe_gate_k = moe_gate_k
self.moe_capacity_factor = moe_capacity_factor

View File

@ -38,6 +38,18 @@ def is_moe_param(param: torch.Tensor) -> bool:
return False
def is_gate_param(param: torch.Tensor) -> bool:
if hasattr(param, "is_gate") and param.is_gate:
return True
return False
def is_norm_param(param: torch.Tensor) -> bool:
if hasattr(param, "is_norm") and param.is_norm:
return True
return False
class MoE(torch.nn.Module):
"""Initialize an MoE layer.
@ -205,55 +217,85 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
data_parallel_group_names.add(param.group_name)
data_parallel_group_names = list(data_parallel_group_names)
group_moe = {}
gate_group = {}
norm_group = {}
# Create the param MoE groups, leave param assign to next step
for param_group in param_groups:
group_moe[param_group["name"]] = {}
for key in data_parallel_group_names:
group_moe[key] = {}
group_moe[key]["name"] = key
group_moe[key]["moe"] = True
group_moe[param_group["name"]][key] = {}
group_moe[param_group["name"]][key]["name"] = key
group_moe[param_group["name"]][key]["moe"] = True
for ori_key in param_group.keys():
if ori_key != "name":
if ori_key == "params":
group_moe[key][ori_key] = []
group_moe[param_group["name"]][key][ori_key] = []
else:
group_moe[key][ori_key] = param_group[ori_key]
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
gate_group["name"] = "gate"
gate_group["gate"] = True
for ori_key in param_group.keys():
if ori_key != "name":
if ori_key == "params":
gate_group[ori_key] = []
else:
gate_group[ori_key] = param_group[ori_key]
norm_group["name"] = "norm"
norm_group["norm"] = True
for ori_key in param_group.keys():
if ori_key != "name":
if ori_key == "params":
norm_group[ori_key] = []
else:
norm_group[ori_key] = param_group[ori_key]
# Assign param
norm_params = []
gate_params = []
for param_group in param_groups:
new_params = []
for param in param_group["params"]:
if is_moe_param(param):
group_moe[param.group_name]["params"].append(param)
# param_group['params'].remove(param)
group_moe[param_group["name"]][param.group_name]["params"].append(param)
elif is_norm_param(param):
norm_params.append(param)
elif is_gate_param(param):
gate_params.append(param)
else:
new_params.append(param)
param_group["params"] = new_params
norm_group["params"] = norm_params
gate_group["params"] = gate_params
param_groups.append(norm_group)
param_groups.append(gate_group)
# Flatten the moe groups
if max_group_size is not None:
for _, v1 in group_moe.items():
cur_group = []
all_groups = []
size_of_cur_group = 0
for param in v1["params"]:
cur_group.append(param)
size_of_cur_group += param.numel()
if size_of_cur_group > max_group_size:
for _, v in group_moe.items():
for _, v1 in v.items():
cur_group = []
all_groups = []
size_of_cur_group = 0
for param in v1["params"]:
if size_of_cur_group + param.numel() <= max_group_size:
cur_group.append(param)
size_of_cur_group += param.numel()
else:
all_groups.append(cur_group)
cur_group = [param]
size_of_cur_group = param.numel()
if cur_group:
all_groups.append(cur_group)
cur_group = []
size_of_cur_group = 0
if cur_group:
all_groups.append(cur_group)
for group in all_groups:
new_dict = {}
for key, val in v1.items():
if key != "params":
new_dict[key] = val
new_dict["params"] = group
param_groups.append(new_dict)
for group in all_groups:
new_dict = {}
for key, val in v1.items():
if key != "params":
new_dict[key] = val
new_dict["params"] = group
param_groups.append(new_dict)
else:
for _, v1 in group_moe.items():
param_groups.append(v1)
for _, v in group_moe.items():
for _, v1 in v.items():
param_groups.append(v1)
return tuple(param_groups)

View File

@ -167,11 +167,6 @@ def _top_idx(source, k):
return torch.topk(source, k=k, dim=0)[1]
@torch.jit.script
def _one_hot_to_float(x, num_classes):
return F.one_hot(x, num_classes=num_classes).float()
def top1gating(
logits: Tensor,
capacity_factor: float,
@ -210,7 +205,7 @@ def top1gating(
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
ce = torch.mean(mask1.type_as(logits), dim=0)
l_aux = torch.sum(me * ce) * num_experts
# Random Token Selection
@ -244,10 +239,10 @@ def top1gating(
locations1_s = torch.sum(locations1 * mask1, dim=1)
# Normalize gate probabilities
mask1_float = mask1.float()
mask1_float = mask1.type_as(logits)
gates = gates * mask1_float
locations1_sc = _one_hot_to_float(locations1_s, capacity)
locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
combine_weights = einsum("se,sc->sec", gates, locations1_sc)
dispatch_mask = combine_weights.bool()
@ -271,7 +266,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# Replace top-expert with min value
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)
@ -286,7 +281,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
ce = torch.mean(mask1.type_as(logits), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts
# Remove locations outside capacity from mask
@ -298,8 +293,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
locations2_s = torch.sum(locations2 * mask2, dim=1)
# Normalize gate probabilities
mask1_float = mask1.float()
mask2_float = mask2.float()
mask1_float = mask1.type_as(logits)
mask2_float = mask2.type_as(logits)
gates1_s = einsum("se,se->s", gates, mask1_float)
gates2_s = einsum("se,se->s", gates, mask2_float)
denom_s = gates1_s + gates2_s
@ -311,8 +306,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
# Calculate combine_weights and dispatch_mask
gates1 = einsum("s,se->se", gates1_s, mask1_float)
gates2 = einsum("s,se->se", gates2_s, mask2_float)
locations1_sc = _one_hot_to_float(locations1_s, capacity)
locations2_sc = _one_hot_to_float(locations2_s, capacity)
locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
combine_weights = combine1_sec + combine2_sec
@ -357,7 +352,7 @@ class TopKGate(Module):
if k not in (1, 2):
raise ValueError("Only top-1 and top-2 gatings are supported.")
# Deepspeed's mechisms, alway use fp32
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.k = k
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
@ -368,6 +363,9 @@ class TopKGate(Module):
self.drop_tokens = drop_tokens
self.use_rts = use_rts
for param in self.wg.parameters():
param.is_gate = True
def forward(
self, inputs: torch.Tensor, used_token: torch.Tensor = None
) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
@ -375,13 +373,10 @@ class TopKGate(Module):
if self.wall_clock_breakdown:
timer("TopKGate").start()
if self.wg.weight.dtype != torch.float32: # TODO can we change it to fp16
self.wg = self.wg.float()
inputs_fp32 = inputs.float()
# input jittering
if self.noisy_gate_policy == "Jitter" and self.training:
inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
logits = self.wg(inputs_fp32)
inputs = multiplicative_jitter(inputs, device=inputs.device)
logits = self.wg(inputs)
if self.k == 1:
gate_output = top1gating(

View File

@ -308,6 +308,12 @@ class HybridZeroOptimizer(BaseOptimizer):
def _is_moe_group(self, param_group):
return "moe" in param_group.keys() and param_group["moe"]
def _is_norm_group(self, param_group):
return "norm" in param_group.keys() and param_group["norm"]
def _is_gate_group(self, param_group):
return "gate" in param_group.keys() and param_group["gate"]
def _attach_reduction_hook(self):
# we iterate over the fp16 params
# on each param, we register a hook to its AccumulateGrad object
@ -688,6 +694,22 @@ class HybridZeroOptimizer(BaseOptimizer):
param_shape == flat_fp32_avg_grads.shape
), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
# Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
# Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
if self._is_norm_group(self.optim.param_groups[group_id]):
dist.all_reduce(
flat_fp32_avg_grads,
op=dist.ReduceOp.AVG,
group=gpc.get_group(ParallelMode.TENSOR),
)
if self._is_gate_group(self.optim.param_groups[group_id]):
dist.all_reduce(
flat_fp32_avg_grads,
op=dist.ReduceOp.AVG,
group=gpc.get_group(ParallelMode.TENSOR),
)
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)