diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 75b0331..e3f9608 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -527,7 +527,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
 
         total_param_norms = {}
-        if self._clip_grad_norm > 0 and len(params) > 0:
+        if len(params) == 0:
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
+
+        if self._clip_grad_norm > 0:
             total_param_norms = compute_param_norm(
                 grads,
                 params,
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index a044420..9833309 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -345,14 +345,15 @@ def compute_param_norm(
     param_grads = {}
 
     for g, p in zip(gradients, parameters):
-        if p.param_name not in param_grads:
-            param_grads[p.param_name] = []
+        param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
+        if param_name not in param_grads:
+            param_grads[param_name] = []
         if (
             gpc.is_initialized(ParallelMode.PIPELINE)
             and hasattr(p, "pipeline_shared_module_pg")
             and dist.get_rank(p.pipeline_shared_module_pg) == 0
         ):  # if shared between different pipe, only count o
-            param_grads[p.param_name].append(g.data.float())
+            param_grads[param_name].append(g.data.float())
         elif (
             gpc.is_initialized(ParallelMode.PIPELINE)
             and hasattr(p, "pipeline_shared_module_pg")
@@ -364,9 +365,9 @@ def compute_param_norm(
             and not is_model_parallel_parameter(p)
             and gpc.get_local_rank(ParallelMode.TENSOR) == 0
         ):  # if not used in each chunk, such as layernorm
-            param_grads[p.param_name].append(g.data.float())
+            param_grads[param_name].append(g.data.float())
         elif is_model_parallel_parameter(p):
-            param_grads[p.param_name].append(g.data.float())
+            param_grads[param_name].append(g.data.float())
         elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
             continue
         else:
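
Note on the change above: when a rank owns no reduced parameters for a group, the optimizer now substitutes padding grad/param tensors so compute_param_norm still runs, and compute_param_norm groups any tensor without a param_name attribute under the key "unknown-padding". The snippet below is a minimal, self-contained sketch of that hasattr-based fallback only; group_grads_by_param_name, the name "layer0.weight", and the toy tensors are illustrative assumptions, not part of the InternLM API, and the parallel-mode branches are omitted.

    import torch

    def group_grads_by_param_name(gradients, parameters):
        """Group fp32 copies of gradients by their owning parameter's name,
        falling back to "unknown-padding" when the tensor carries no
        param_name attribute (as a bare padding tensor would)."""
        param_grads = {}
        for g, p in zip(gradients, parameters):
            # Same guard as the patch: fall back when param_name is absent.
            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
            param_grads.setdefault(param_name, []).append(g.data.float())
        return param_grads

    # A real parameter tagged with a (hypothetical) name, plus a bare padding tensor.
    w = torch.nn.Parameter(torch.randn(4))
    w.param_name = "layer0.weight"
    padding = torch.zeros(1)  # stands in for the optimizer's padding tensor

    grouped = group_grads_by_param_name(
        gradients=[torch.ones(4), torch.zeros(1)],
        parameters=[w, padding],
    )
    assert set(grouped) == {"layer0.weight", "unknown-padding"}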