diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 108bb58..491e59c 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -633,10 +633,10 @@ class HybridZeroOptimizer(BaseOptimizer):

         # compute norm for gradients in the last bucket
         total_norms = {}
-        total_param_norms = {}
+        total_param_grad_norms = {}
+        total_layer_grad_norms = {}
         total_param_zero_grad_count = {}
         total_layer_zero_grad_count = {}
-        total_layer_norms = {}
         for group_id in range(self.num_param_groups):
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
@@ -653,7 +653,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     last_stage=True,
                     previous_param_norms=groups_param_norms[group_id],
                 )
-                total_layer_norms[group_name], total_param_norms[group_name] = compute_layer_norm(
+                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
                     param_norms=param_norms, loss_scale=self.loss_scale.item()
                 )
             if grad_profiling_config.get("zero_grad_profiling", False):
@@ -674,8 +674,8 @@ class HybridZeroOptimizer(BaseOptimizer):

         state, global_norms = self._step(closure=closure, norms=total_norms)
         if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_norm"] = total_layer_norms
-            global_norms["param_norm"] = total_param_norms
+            global_norms["layer_grad_norm"] = total_layer_grad_norms
+            global_norms["param_grad_norm"] = total_param_grad_norms
         if grad_profiling_config.get("zero_grad_profiling", False):
             global_norms["layer_zero_grad"] = total_layer_zero_grad_count
             global_norms["param_zero_grad"] = total_param_zero_grad_count
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 7010ca5..f79606a 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -533,8 +533,8 @@ def record_current_batch_training_metrics(
         if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
            "zero_grad_profiling", False
        ):
-            layer_metrics = ["layer_norm", "layer_zero_grad"]
-            param_metrics = ["param_norm", "param_zero_grad"]
+            layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
+            param_metrics = ["param_grad_norm", "param_zero_grad"]
             layer_names = grad_profiling_config.get("layers", [])
             for layer_metric_name in layer_metrics:
                 layer_metric = grad_norm.get(layer_metric_name, {})
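
For reviewers, the net effect of this diff is a pure key rename: the gradient-norm profiling results move from `layer_norm`/`param_norm` to `layer_grad_norm`/`param_grad_norm`, presumably to avoid ambiguity with `nn.LayerNorm`-style names, and the metrics recorder in `record_current_batch_training_metrics` is updated to look up the new keys. Below is a minimal sketch of a downstream consumer reading the renamed keys out of the `global_norms` dict that `step()` produces. The nesting shown (metric key → group name → per-layer values) and the `log_grad_profiling` helper are illustrative assumptions for this sketch, not InternLM API:

```python
# Illustrative sketch only: assumes global_norms nests values as
# {metric_key: {group_name: {name: value}}}; the real layout in
# InternLM may differ. log_grad_profiling is a hypothetical helper.
def log_grad_profiling(global_norms: dict) -> None:
    # After this diff, gradient-norm metrics live under the *_grad_norm
    # keys; the old "layer_norm"/"param_norm" keys are gone.
    for metric_key in ("layer_grad_norm", "param_grad_norm", "layer_zero_grad", "param_zero_grad"):
        for group_name, values in global_norms.get(metric_key, {}).items():
            for name, value in values.items():
                print(f"{metric_key}/{group_name}/{name}: {value}")
```

Any external dashboards or log parsers keyed on the old `layer_norm`/`param_norm` names would need the same rename.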