test moe layer norm

pull/429/head
JiaoPL 2023-10-20 16:27:28 +08:00
parent 1ac7e4b489
commit 45fd0ec86b
2 changed files with 12 additions and 6 deletions

@@ -527,7 +527,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)

         total_param_norms = {}
-        if self._clip_grad_norm > 0 and len(params) > 0:
+        if len(params) == 0:
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
+
+        if self._clip_grad_norm > 0:
             total_param_norms = compute_param_norm(
                 grads,
                 params,
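
A note on this hunk: the old guard skipped compute_param_norm entirely when a rank held no parameters for the group, which can happen when MoE expert parameters live on other ranks. Since the norm computation presumably involves collective communication that every rank must join, the change pads empty groups with zero tensors instead of skipping the call. Below is a minimal sketch of the idea; pad_if_empty is a hypothetical helper, the buffer names come from the diff, and their shape is an illustrative assumption:

import torch

# Stand-ins for the optimizer's pre-allocated zero buffers; the names come
# from the diff, the shape is an illustrative assumption.
padding_tensor = torch.zeros(32)
padding_grad = torch.zeros(32)

def pad_if_empty(params, grads, dtype):
    """Substitute zero padding when this rank owns no parameters in the
    group, so it still calls compute_param_norm together with the other
    ranks; the zeros contribute nothing to the norm."""
    if len(params) == 0:
        grads = [padding_grad.to(dtype)]
        params = [padding_tensor.to(dtype)]
    return params, grads

# An empty group gets padded; a non-empty one passes through untouched.
params, grads = pad_if_empty([], [], torch.bfloat16)
assert params[0].dtype == torch.bfloat16 and float(grads[0].sum()) == 0.0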

@@ -345,14 +345,15 @@ def compute_param_norm(

     param_grads = {}
     for g, p in zip(gradients, parameters):
-        if p.param_name not in param_grads:
-            param_grads[p.param_name] = []
+        param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
+        if param_name not in param_grads:
+            param_grads[param_name] = []
         if (
             gpc.is_initialized(ParallelMode.PIPELINE)
             and hasattr(p, "pipeline_shared_module_pg")
             and dist.get_rank(p.pipeline_shared_module_pg) == 0
         ):  # if shared between different pipes, only count once
-            param_grads[p.param_name].append(g.data.float())
+            param_grads[param_name].append(g.data.float())
         elif (
             gpc.is_initialized(ParallelMode.PIPELINE)
             and hasattr(p, "pipeline_shared_module_pg")
@@ -364,9 +365,9 @@ def compute_param_norm(
             and not is_model_parallel_parameter(p)
             and gpc.get_local_rank(ParallelMode.TENSOR) == 0
         ):  # if not used in each chunk, such as layernorm
-            param_grads[p.param_name].append(g.data.float())
+            param_grads[param_name].append(g.data.float())
         elif is_model_parallel_parameter(p):
-            param_grads[p.param_name].append(g.data.float())
+            param_grads[param_name].append(g.data.float())
         elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
             continue
         else:
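
This change pairs with the padding in the first file: the padding tensors carry no param_name attribute, so the old p.param_name lookups would raise AttributeError; the hasattr fallback buckets them under "unknown-padding" instead. Below is a minimal sketch of the bucketing idiom with the pipeline- and tensor-parallel branches omitted; bucket_grads_by_name is a hypothetical helper, and getattr with a default has the same effect as the diff's hasattr check:

import torch

def bucket_grads_by_name(gradients, parameters):
    """Group gradient tensors by their owning parameter's name; tensors
    without a param_name (e.g. the padding tensors above) share the
    "unknown-padding" bucket instead of raising AttributeError."""
    param_grads = {}
    for g, p in zip(gradients, parameters):
        name = getattr(p, "param_name", "unknown-padding")  # same effect as the hasattr check
        param_grads.setdefault(name, []).append(g.data.float())
    return param_grads

# Example: one named parameter plus one anonymous padding tensor.
w = torch.ones(2)
w.param_name = "w"  # attribute name taken from the diff
pad = torch.zeros(2)
buckets = bucket_grads_by_name([torch.ones(2), pad], [w, pad])
assert set(buckets) == {"w", "unknown-padding"}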