mirror of https://github.com/InternLM/InternLM
fix(moe): remove norm&gate force sync (#448)
* add zero broadcast_sync
* delete old sync logic
* fix merged error
* refactor code
* remove some unused function (is norm/gate group)

pull/464/head
parent f77f376fd6
commit 21624f6f81
@@ -293,10 +293,6 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
-    assert not (
-        gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
-    ), "FSDP does not support num_experts > 1"
-
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
@@ -345,11 +341,13 @@ def args_sanity_check():
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
     # moe not support overlap and zero1.5 for now
-    if hasattr(gpc.config.model, "num_experts"):
+    if gpc.config.model.get("num_experts", 1) > 1:
+        assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1"
         assert (
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
         assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
+        assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now"
 
 
 def launch(
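Taken together, the updated sanity checks require a MoE run (num_experts > 1) to use plain ZeRO-1 without FSDP, without sequence parallelism, and without enabling both overlap flags at once. A hypothetical config fragment that would pass these checks might look like the sketch below; only the asserted fields come from the diff, and the surrounding key names (e.g. hybrid_zero_optimizer) are assumptions based on typical InternLM configs.

# Hypothetical config fragment satisfying the new MoE asserts in args_sanity_check.
model = dict(
    num_experts=4,          # MoE enabled
    moe_gate_k=2,           # default filled in by args_sanity_check if missing
    moe_use_residual=False, # default filled in by args_sanity_check if missing
)
parallel = dict(
    zero1=dict(size=-1, fsdp=False),  # "moe only support zero1"; FSDP is rejected
    sequence_parallel=False,          # sequence parallel not supported with MoE yet
)
hybrid_zero_optimizer = dict(         # assumed key name for optim_ckpt in the config
    overlap_sync_grad=False,          # enabling both overlap flags with MoE is rejected
    overlap_sync_param=False,
)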
@@ -413,7 +411,7 @@ def launch(
             f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, "
             f"tensor parallel size: {gpc.tensor_parallel_size}",
         )
-        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+        if gpc.config.model.get("num_experts", 1) > 1:
             logger.info(
                 f"Creating MoE with num_experts: {gpc.config.model.num_experts} | "
                 f"expert parallel size: {gpc.expert_parallel_size} | "
@@ -127,11 +127,6 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
-
-        for param in self.norm1.parameters():
-            param.is_norm = True
-        for param in self.norm2.parameters():
-            param.is_norm = True
         set_fp32_attr_to_module(self.norm1)
         set_fp32_attr_to_module(self.norm2)
 
@@ -215,18 +215,6 @@ def is_moe_param(param: torch.Tensor) -> bool:
     return False
 
 
-def is_gate_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_gate") and param.is_gate:
-        return True
-    return False
-
-
-def is_norm_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_norm") and param.is_norm:
-        return True
-    return False
-
-
 def Silu(w1_o, w2_o):
     return F.silu(w1_o) * w2_o
 
@@ -352,9 +352,6 @@ class TopKGate(Module):
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts
 
-        for param in self.wg.parameters():
-            param.is_gate = True
-
     def forward(
         self, inputs: torch.Tensor, used_token: torch.Tensor = None
     ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
@@ -274,12 +274,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _is_moe_group(self, param_group):
         return "moe" in param_group.keys() and param_group["moe"]
 
-    def _is_norm_group(self, param_group):
-        return "norm" in param_group.keys() and param_group["norm"]
-
-    def _is_gate_group(self, param_group):
-        return "gate" in param_group.keys() and param_group["gate"]
-
     # TODO check expert dp is correct when enable moe and overlap both
     def _attach_reduction_hook(self):
        # we iterate over the fp16 params
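After this hunk, _is_moe_group is the only group predicate left; it keys off the "moe": True flag that split_params_into_different_groups_for_optimizer attaches to expert groups (see the last two hunks below). A minimal sketch of how the predicate treats the two kinds of groups; the group dicts here are illustrative stand-ins, not taken verbatim from the repo.

# Illustrative sketch: distinguishing expert param groups from the fp32/default groups.
def _is_moe_group(param_group):
    return "moe" in param_group.keys() and param_group["moe"]

fp32_group = {"name": "fp32", "params": []}                        # no "moe" key -> False
expert_group = {"name": "expert_group_0", "moe": True, "params": []}  # hypothetical name -> True

assert not _is_moe_group(fp32_group)
assert _is_moe_group(expert_group)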
@@ -566,7 +560,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             last_stage=last_stage,
             previous_param_norms=previous_param_norms,
             zero_mode=self._broadcast_parallel_mode[group_id],
-            is_moe_group=self._is_moe_group(self.optim.param_groups[group_id]),
         )
         return total_param_norms
 
@@ -589,7 +582,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             last_stage=last_stage,
             previous_zero_grad_count=previous_zero_grad_count,
             zero_mode=self._broadcast_parallel_mode[group_id],
-            is_moe_group=self._is_moe_group(self.optim.param_groups[group_id]),
         )
         return total_zero_grad_count
 
@@ -675,16 +667,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                     total_param_zero_grad_count[group_name],
                 ) = compute_layer_zero_grad_count(zero_grad_count)
 
-            # Need to allreduce(avg) the norms across different ranks because moe params will not be synced
-            # during allreduce
-            if self._is_moe_group(self.optim.param_groups[group_id]):
-                # model and zero have been reduced!!!
-                pg = gpc.get_group(ParallelMode.EXPERT)
-                scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
-                scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
-                dist.all_reduce(scaled_norm_tensor, group=pg)
-                total_norms[group_name] = scaled_norm_tensor.item()
-
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
@@ -767,19 +749,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 param_shape == flat_fp32_avg_grads.shape
             ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
 
-            # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
-            # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
-            is_tp_sync_groups = (
-                self._is_norm_group(self.optim.param_groups[group_id]),
-                self._is_gate_group(self.optim.param_groups[group_id]),
-            )
-            if any(is_tp_sync_groups):
-                dist.all_reduce(
-                    flat_fp32_avg_grads,
-                    op=dist.ReduceOp.AVG,
-                    group=gpc.get_group(ParallelMode.TENSOR),
-                )
-
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
@@ -262,7 +262,12 @@ def reduce_grads(gradients, parameters, fine_grained=False):
 
 
 def compute_norm(
-    gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1
+    gradients,
+    parameters,
+    last_stage=False,
+    previous_norm=None,
+    norm_type=2,
+    zero_mode=ParallelMode.ZERO1,
 ):
     """Get the norm
     Arguments:
|
@ -335,6 +340,15 @@ def compute_norm(
|
||||||
if torch.is_tensor(total_norm):
|
if torch.is_tensor(total_norm):
|
||||||
total_norm = total_norm.item()
|
total_norm = total_norm.item()
|
||||||
|
|
||||||
|
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
||||||
|
# model and zero have been reduced!!!
|
||||||
|
if zero_mode == ParallelMode.EXPERT_DATA:
|
||||||
|
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||||
|
scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
||||||
|
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
||||||
|
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||||
|
total_norm = scaled_norm_tensor.item()
|
||||||
|
|
||||||
# Scale.
|
# Scale.
|
||||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||||
total_norm = -1
|
total_norm = -1
|
||||||
|
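The expert-group norm correction that used to live in HybridZeroOptimizer (see the removed block in the -675 hunk above) now sits inside compute_norm and is triggered by zero_mode rather than an is_moe_group flag. The correction is an average implemented as scale-then-sum over the expert process group. A self-contained sketch of that pattern; the function name and signature are mine, not the repo's API.

import torch
import torch.distributed as dist

def average_norm_over_group(total_norm: float, group, group_world_size: int, device) -> float:
    """Average a locally computed norm across the ranks of `group`.

    Mirrors the pattern added to compute_norm: pre-divide by the world size,
    then all_reduce with the default SUM op, which yields the mean.
    """
    scaled = torch.tensor(total_norm / float(group_world_size), device=device, dtype=torch.float)
    dist.all_reduce(scaled, group=group)  # SUM of pre-scaled values == average
    return scaled.item()

Note that, as shown in the diff, the new code divides by gpc.get_world_size(ParallelMode.DATA) while reducing over the EXPERT group, whereas the removed optimizer-side version divided by the EXPERT world size.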
@@ -353,7 +367,6 @@ def compute_param_metric(
     previous_param_metrics=None,
     norm_type=2,
     zero_mode=ParallelMode.ZERO1,
-    is_moe_group=False,
 ):
     """Get the metrics of params
     Argumemts:
@@ -424,7 +437,7 @@ def compute_param_metric(
             total_metrics[param_name] += param_metric
 
     # moe
-    if is_moe_group:
+    if zero_mode == ParallelMode.EXPERT_DATA:
         pg = gpc.get_group(ParallelMode.EXPERT)
         total_metric_values = list(total_metrics.values())
         if isinstance(total_metric_values[0], torch.Tensor):
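The same zero_mode check replaces the is_moe_group flag in compute_param_metric (and, through it, compute_param_norm and compute_zero_grad_count), so the callers in HybridZeroOptimizer only pass self._broadcast_parallel_mode[group_id], which presumably reflects each param group's dp_mode. A minimal stand-in showing the dispatch; the ParallelMode enum below is a local stub, not the internlm one.

from enum import Enum

class ParallelMode(Enum):  # stub for internlm.core.context.ParallelMode
    ZERO1 = "zero1"
    DATA = "data"
    EXPERT_DATA = "expert_data"

def needs_expert_group_reduction(zero_mode: ParallelMode) -> bool:
    # Matches the condition used by compute_norm / compute_param_metric after this change.
    return zero_mode == ParallelMode.EXPERT_DATA

assert needs_expert_group_reduction(ParallelMode.EXPERT_DATA)
assert not needs_expert_group_reduction(ParallelMode.ZERO1)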
@@ -463,7 +476,6 @@ def compute_param_norm(
     previous_param_norms=None,
     norm_type=2,
     zero_mode=ParallelMode.ZERO1,
-    is_moe_group=False,
 ):
     """Get the norm of params
     Arguments:
@@ -484,7 +496,6 @@ def compute_param_norm(
         previous_param_metrics=previous_param_norms,
         norm_type=norm_type,
         zero_mode=zero_mode,
-        is_moe_group=is_moe_group,
     )
 
 
@@ -494,7 +505,6 @@ def compute_zero_grad_count(
     last_stage=False,
     previous_zero_grad_count=None,
     zero_mode=ParallelMode.ZERO1,
-    is_moe_group=False,
 ):
     """Get the count of zero gradient for each parameters
     Arguments:
@@ -512,7 +522,6 @@ def compute_zero_grad_count(
         last_stage=last_stage,
         previous_param_metrics=previous_zero_grad_count,
         zero_mode=zero_mode,
-        is_moe_group=is_moe_group,
     )
 
 
@@ -4,7 +4,7 @@ import torch
 
 from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
+from internlm.model.utils import is_moe_param
 
 
 def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
@@ -39,10 +39,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA}
-    if gpc.config.model.get("num_experts", 0) > 1:
-        # norm and gate are special group to force sync (when enable MoE).
-        for key in ["gate", "norm"]:
-            new_groups[key] = {"name": key, key: True, "params": [], "dp_mode": ParallelMode.DATA}
+    if gpc.config.model.get("num_experts", 1) > 1:
         for key in gpc.expert_parallel_group_names:
             new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA}
 
@@ -58,12 +55,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # first split the norm and gate groups, which are special case to force sync (when enable MoE),
     # then fp32 group and the moe group.
     for param in pgroup["params"]:
-        if gpc.config.model.get("num_experts", 0) > 1 and is_norm_param(param):
-            new_groups["norm"]["params"].append(param)
-        # gate param means MoE is enabled
-        elif is_gate_param(param):
-            new_groups["gate"]["params"].append(param)
-        elif param.dtype == torch.float32:
+        if param.dtype == torch.float32:
             new_groups["fp32"]["params"].append(param)
         # moe param means MoE is enabled
         elif is_moe_param(param):
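With the force-sync groups gone, parameters are split only into the fp32 group, the default group, and the per-expert-parallel groups; norm and gate weights now land in whichever of those buckets their dtype or MoE flag puts them in (norms, presumably kept in fp32 via set_fp32_attr_to_module, would fall into the "fp32" group). A toy illustration of the remaining splitting rule; the is_expert attribute is a simplified stand-in for is_moe_param, and the real code additionally spreads expert params across per-expert-parallel groups.

import torch
import torch.nn as nn

def split_params_toy(params):
    """Toy version of the grouping rule left after this commit:
    fp32 params -> "fp32", MoE expert params -> "moe", everything else -> "default"."""
    groups = {"default": [], "fp32": [], "moe": []}
    for p in params:
        if p.dtype == torch.float32:
            groups["fp32"].append(p)
        elif getattr(p, "is_expert", False):  # stand-in for is_moe_param(p)
            groups["moe"].append(p)
        else:
            groups["default"].append(p)
    return groups

# Example: a bf16 linear weight, an fp32 LayerNorm weight, and a flagged expert weight.
w = nn.Linear(4, 4).to(torch.bfloat16).weight
norm_w = nn.LayerNorm(4).weight               # stays float32
expert_w = nn.Linear(4, 4).to(torch.bfloat16).weight
expert_w.is_expert = True                     # hypothetical flag standing in for is_moe_param
buckets = split_params_toy([w, norm_w, expert_w])
assert buckets["fp32"][0] is norm_w
assert buckets["moe"][0] is expert_w
assert buckets["default"][0] is w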