diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 3fb90a8..9ab410c 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -123,6 +123,11 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
 
+        for param in self.norm1.parameters():
+            param.is_norm = True
+        for param in self.norm2.parameters():
+            param.is_norm = True
+
         self.num_experts = num_experts
         self.moe_gate_k = moe_gate_k
         self.moe_capacity_factor = moe_capacity_factor
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 6d02779..e102df9 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -38,6 +38,18 @@ def is_moe_param(param: torch.Tensor) -> bool:
     return False
 
 
+def is_gate_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_gate") and param.is_gate:
+        return True
+    return False
+
+
+def is_norm_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_norm") and param.is_norm:
+        return True
+    return False
+
+
 class MoE(torch.nn.Module):
     """Initialize an MoE layer.
 
@@ -205,55 +217,85 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
         data_parallel_group_names.add(param.group_name)
     data_parallel_group_names = list(data_parallel_group_names)
     group_moe = {}
+    gate_group = {}
+    norm_group = {}
     # Create the param MoE groups, leave param assign to next step
     for param_group in param_groups:
+        group_moe[param_group["name"]] = {}
         for key in data_parallel_group_names:
-            group_moe[key] = {}
-            group_moe[key]["name"] = key
-            group_moe[key]["moe"] = True
+            group_moe[param_group["name"]][key] = {}
+            group_moe[param_group["name"]][key]["name"] = key
+            group_moe[param_group["name"]][key]["moe"] = True
             for ori_key in param_group.keys():
                 if ori_key != "name":
                     if ori_key == "params":
-                        group_moe[key][ori_key] = []
+                        group_moe[param_group["name"]][key][ori_key] = []
                     else:
-                        group_moe[key][ori_key] = param_group[ori_key]
-
+                        group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
+        gate_group["name"] = "gate"
+        gate_group["gate"] = True
+        for ori_key in param_group.keys():
+            if ori_key != "name":
+                if ori_key == "params":
+                    gate_group[ori_key] = []
+                else:
+                    gate_group[ori_key] = param_group[ori_key]
+        norm_group["name"] = "norm"
+        norm_group["norm"] = True
+        for ori_key in param_group.keys():
+            if ori_key != "name":
+                if ori_key == "params":
+                    norm_group[ori_key] = []
+                else:
+                    norm_group[ori_key] = param_group[ori_key]
     # Assign param
+    norm_params = []
+    gate_params = []
     for param_group in param_groups:
         new_params = []
         for param in param_group["params"]:
             if is_moe_param(param):
-                group_moe[param.group_name]["params"].append(param)
-                # param_group['params'].remove(param)
+                group_moe[param_group["name"]][param.group_name]["params"].append(param)
+            elif is_norm_param(param):
+                norm_params.append(param)
+            elif is_gate_param(param):
+                gate_params.append(param)
             else:
                 new_params.append(param)
         param_group["params"] = new_params
+    norm_group["params"] = norm_params
+    gate_group["params"] = gate_params
+    param_groups.append(norm_group)
+    param_groups.append(gate_group)
 
     # Flatten the moe groups
     if max_group_size is not None:
-        for _, v1 in group_moe.items():
-            cur_group = []
-            all_groups = []
-            size_of_cur_group = 0
-            for param in v1["params"]:
-                cur_group.append(param)
-                size_of_cur_group += param.numel()
-                if size_of_cur_group > max_group_size:
+        for _, v in group_moe.items():
+            for _, v1 in v.items():
+                cur_group = []
+                all_groups = []
+                size_of_cur_group = 0
+                for param in v1["params"]:
+                    if size_of_cur_group + param.numel() <= max_group_size:
+                        cur_group.append(param)
+                        size_of_cur_group += param.numel()
+                    else:
+                        all_groups.append(cur_group)
+                        cur_group = [param]
+                        size_of_cur_group = param.numel()
+                if cur_group:
                     all_groups.append(cur_group)
-                    cur_group = []
-                    size_of_cur_group = 0
-            if cur_group:
-                all_groups.append(cur_group)
-            for group in all_groups:
-                new_dict = {}
-                for key, val in v1.items():
-                    if key != "params":
-                        new_dict[key] = val
-                new_dict["params"] = group
-                param_groups.append(new_dict)
+                for group in all_groups:
+                    new_dict = {}
+                    for key, val in v1.items():
+                        if key != "params":
+                            new_dict[key] = val
+                    new_dict["params"] = group
+                    param_groups.append(new_dict)
     else:
-        for _, v1 in group_moe.items():
-            param_groups.append(v1)
+        for _, v in group_moe.items():
+            for _, v1 in v.items():
+                param_groups.append(v1)
 
     return tuple(param_groups)
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 66b41d0..6a6c511 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -167,11 +167,6 @@ def _top_idx(source, k):
     return torch.topk(source, k=k, dim=0)[1]
 
 
-@torch.jit.script
-def _one_hot_to_float(x, num_classes):
-    return F.one_hot(x, num_classes=num_classes).float()
-
-
 def top1gating(
     logits: Tensor,
     capacity_factor: float,
@@ -210,7 +205,7 @@
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
-    ce = torch.mean(mask1.float(), dim=0)
+    ce = torch.mean(mask1.type_as(logits), dim=0)
    l_aux = torch.sum(me * ce) * num_experts
 
     # Random Token Selection
@@ -244,10 +239,10 @@
     locations1_s = torch.sum(locations1 * mask1, dim=1)
 
     # Normalize gate probabilities
-    mask1_float = mask1.float()
+    mask1_float = mask1.type_as(logits)
     gates = gates * mask1_float
 
-    locations1_sc = _one_hot_to_float(locations1_s, capacity)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
     combine_weights = einsum("se,sc->sec", gates, locations1_sc)
 
     dispatch_mask = combine_weights.bool()
@@ -271,7 +266,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
     logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # Replace top-expert with min value
-    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
+    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
     indices2_s = torch.argmax(logits_except1, dim=1)
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)
 
@@ -286,7 +281,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
-    ce = torch.mean(mask1.float(), dim=0)
+    ce = torch.mean(mask1.type_as(logits), dim=0)
     l_aux = torch.mean(me * ce) * num_experts * num_experts
 
     # Remove locations outside capacity from mask
@@ -298,8 +293,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     locations2_s = torch.sum(locations2 * mask2, dim=1)
 
     # Normalize gate probabilities
-    mask1_float = mask1.float()
-    mask2_float = mask2.float()
+    mask1_float = mask1.type_as(logits)
+    mask2_float = mask2.type_as(logits)
     gates1_s = einsum("se,se->s", gates, mask1_float)
     gates2_s = einsum("se,se->s", gates, mask2_float)
     denom_s = gates1_s + gates2_s
@@ -311,8 +306,8 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = _one_hot_to_float(locations1_s, capacity)
-    locations2_sc = _one_hot_to_float(locations2_s, capacity)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
+    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
     combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
     combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
     combine_weights = combine1_sec + combine2_sec
@@ -357,7 +352,7 @@ class TopKGate(Module):
         if k not in (1, 2):
             raise ValueError("Only top-1 and top-2 gatings are supported.")
         # Deepspeed's mechisms, alway use fp32
-        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.k = k
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
@@ -368,6 +363,9 @@ class TopKGate(Module):
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts
 
+        for param in self.wg.parameters():
+            param.is_gate = True
+
     def forward(
         self, inputs: torch.Tensor, used_token: torch.Tensor = None
     ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
@@ -375,13 +373,10 @@
         if self.wall_clock_breakdown:
             timer("TopKGate").start()
 
-        if self.wg.weight.dtype != torch.float32:  # TODO can we change it to fp16
-            self.wg = self.wg.float()
-        inputs_fp32 = inputs.float()
         # input jittering
         if self.noisy_gate_policy == "Jitter" and self.training:
-            inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
-        logits = self.wg(inputs_fp32)
+            inputs = multiplicative_jitter(inputs, device=inputs.device)
+        logits = self.wg(inputs)
 
         if self.k == 1:
             gate_output = top1gating(
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 66ee037..80aaa7b 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -308,6 +308,12 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _is_moe_group(self, param_group):
        return "moe" in param_group.keys() and param_group["moe"]
 
+    def _is_norm_group(self, param_group):
+        return "norm" in param_group.keys() and param_group["norm"]
+
+    def _is_gate_group(self, param_group):
+        return "gate" in param_group.keys() and param_group["gate"]
+
     def _attach_reduction_hook(self):
         # we iterate over the fp16 params
         # on each param, we register a hook to its AccumulateGrad object
@@ -688,6 +694,22 @@ class HybridZeroOptimizer(BaseOptimizer):
                 param_shape == flat_fp32_avg_grads.shape
             ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
 
+            # Parameters replicated across the TP group, such as norm and the moe gate, can end up with slightly
+            # inconsistent gradients, so all-reduce them within the TP group to eliminate the accumulated error.
+            if self._is_norm_group(self.optim.param_groups[group_id]):
+                dist.all_reduce(
+                    flat_fp32_avg_grads,
+                    op=dist.ReduceOp.AVG,
+                    group=gpc.get_group(ParallelMode.TENSOR),
+                )
+
+            if self._is_gate_group(self.optim.param_groups[group_id]):
+                dist.all_reduce(
+                    flat_fp32_avg_grads,
+                    op=dist.ReduceOp.AVG,
+                    group=gpc.get_group(ParallelMode.TENSOR),
+                )
+
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
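For context, below is a minimal, self-contained sketch (not part of the patch) of how the new `is_norm` / `is_gate` tagging is meant to drive optimizer param-group splitting. The `is_gate_param` / `is_norm_param` helpers mirror the ones added in `internlm/model/moe.py`, but `split_groups` is a hypothetical, simplified stand-in for `split_params_into_different_moe_groups_for_optimizer` (it ignores MoE expert params and `max_group_size` flattening), so treat it as an illustration rather than the real API.

```python
# Sketch only: `split_groups` is a simplified, hypothetical stand-in for
# split_params_into_different_moe_groups_for_optimizer. Requires only plain PyTorch.
import torch


def is_gate_param(param: torch.Tensor) -> bool:
    # Mirrors the helper added in internlm/model/moe.py.
    return getattr(param, "is_gate", False)


def is_norm_param(param: torch.Tensor) -> bool:
    # Mirrors the helper added in internlm/model/moe.py.
    return getattr(param, "is_norm", False)


def split_groups(param_group: dict) -> list:
    """Split one optimizer param group into base, norm, and gate groups."""
    base, norm_params, gate_params = [], [], []
    for p in param_group["params"]:
        if is_norm_param(p):
            norm_params.append(p)
        elif is_gate_param(p):
            gate_params.append(p)
        else:
            base.append(p)
    # Copy every hyper-parameter except "name" and "params", as the patch does.
    shared = {k: v for k, v in param_group.items() if k not in ("name", "params")}
    return [
        {"name": param_group["name"], "params": base, **shared},
        {"name": "norm", "norm": True, "params": norm_params, **shared},
        {"name": "gate", "gate": True, "params": gate_params, **shared},
    ]


if __name__ == "__main__":
    norm = torch.nn.LayerNorm(8)
    gate = torch.nn.Linear(8, 4, bias=False)  # plays the role of TopKGate.wg
    ffn = torch.nn.Linear(8, 8)

    # Tag parameters the same way the patch does in the layer / gate constructors.
    for p in norm.parameters():
        p.is_norm = True
    for p in gate.parameters():
        p.is_gate = True

    group = {
        "name": "default",
        "lr": 1e-4,
        "params": list(norm.parameters()) + list(gate.parameters()) + list(ffn.parameters()),
    }
    for g in split_groups(group):
        print(g["name"], [tuple(p.shape) for p in g["params"]])
    # Expected: default -> ffn weight/bias, norm -> [(8,), (8,)], gate -> [(4, 8)]
```

The dedicated "norm" and "gate" groups produced this way are what `HybridZeroOptimizer` later detects via `_is_norm_group` / `_is_gate_group` so that it can all-reduce those gradients across the tensor-parallel group.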