From 641ee14bbfacc14320f6a20d2c4a68830d4c2c15 Mon Sep 17 00:00:00 2001 From: JiaoPL Date: Fri, 13 Oct 2023 12:07:58 +0800 Subject: [PATCH] update layer norm to tensorboard --- internlm/core/engine.py | 4 +- .../solver/optimizer/hybrid_zero_optim.py | 51 +++++++++++++++---- internlm/solver/optimizer/utils.py | 4 +- internlm/train/training_internlm.py | 6 +++ train.py | 3 +- 5 files changed, 53 insertions(+), 15 deletions(-) diff --git a/internlm/core/engine.py b/internlm/core/engine.py index a372b9e..f3e38e3 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -115,7 +115,7 @@ class Engine: self._all_reduce_gradients() self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm) - success, grad_norm = self.optimizer.step() + success, grad_norm, layer_grad_norm = self.optimizer.step() if success and self._lr_scheduler is not None: self._lr_scheduler.step() @@ -123,7 +123,7 @@ class Engine: if success and self._beta2_scheduler is not None: self._beta2_scheduler.step() - return success, grad_norm + return success, grad_norm, layer_grad_norm def train(self): """Sets the model to training mode.""" diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index d60d1a0..70064b8 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -564,27 +564,55 @@ class HybridZeroOptimizer(BaseOptimizer): total_layernorms[group_name] = self._compute_norm_with_stage( group_id=group_id, last_bucket=True, last_stage=True, previous_layer_norms=groups_layer_norms[group_id] ) - total_norms[group_name] = sum(total_layernorms[group_name].values()) # Need to allreduce(avg) the norms across different ranks because moe params will not be synced # during allreduce if self._is_moe_group(self.optim.param_groups[group_id]): # model and zero have been reduced!!! 
pg = gpc.get_group(ParallelMode.EXPERT) - scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - total_norms[group_name] = scaled_norm_tensor.item() + # layer_norms allreduce + scaled_layer_norm = torch.cuda.FloatTensor( + list(total_layernorms[group_name].values()), device=get_current_device() + ) + scaled_layer_norm = scaled_layer_norm / float(gpc.get_world_size(ParallelMode.EXPERT)) + dist.all_reduce(scaled_layer_norm, group=pg) + for i, layer_name in enumerate(total_layernorms[group_name].keys()): + total_layernorms[group_name][layer_name] = scaled_layer_norm[i].item() + + # compute total_norms using the layer grad_norm + total_layer_norms_values = list(total_layernorms[group_name].values()) + # inf flag + if -1 in total_layer_norms_values: + total_norms[group_name] = -1 + # nan flag + elif -2 in total_layer_norms_values: + total_norms[group_name] = -2 + else: + total_norms[group_name] = sum(total_layer_norms_values) timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - return self._step(closure=closure, norms=total_norms) + return self._step(closure=closure, norms=total_norms, layer_norms=total_layernorms) - def _step(self, closure=None, norms=None): + def _step(self, closure=None, norms=None, layer_norms=None): assert closure is None, "closure is not supported by step()" + def scale_layer_norm(layer_norms, loss_scale): + global_layer_norm_groups = {} + if layer_norms: + for group_name, layer_norm_dict in layer_norms.items(): + global_layer_norm_groups[group_name] = {} + for layer_name, norm in layer_norm_dict.items(): + # filter unknown + if layer_name == "unknown" and norm == 0: + continue + # handle inf (-1) and nan (-2) + if norm != -1 and norm != -2: + global_layer_norm_groups[group_name][layer_name] = norm**0.5 / loss_scale + return 
global_layer_norm_groups + # check for overflow found_inf = False found_nan = False @@ -603,6 +631,9 @@ class HybridZeroOptimizer(BaseOptimizer): if gpc.config.model.dtype is not torch.float32: self.grad_scaler.update(found_inf) + # scale layer norm + global_layer_norm_groups = scale_layer_norm(layer_norms, loss_scale) + # update loss scale if overflow occurs if found_inf: if gpc.is_rank_for_log(): @@ -613,7 +644,7 @@ class HybridZeroOptimizer(BaseOptimizer): ) self._grad_store._averaged_gradients = dict() self.zero_grad() - return False, norms + return False, norms, global_layer_norm_groups if found_nan: if gpc.is_rank_for_log(): @@ -624,7 +655,7 @@ class HybridZeroOptimizer(BaseOptimizer): ) self._grad_store._averaged_gradients = dict() self.zero_grad() - return False, norms + return False, norms, global_layer_norm_groups # copy the grad of fp16 param to fp32 param single_grad_partition_groups = [] @@ -711,7 +742,7 @@ class HybridZeroOptimizer(BaseOptimizer): # so synchronization is maintained for group_name, global_norm in global_norm_groups.items(): global_norm_groups[group_name] = global_norm / loss_scale - return True, global_norm_groups + return True, global_norm_groups, global_layer_norm_groups def broadcast_params(self): handles = [] diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 8775c5f..393e054 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -319,7 +319,7 @@ def compute_norm( dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL)) dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode)) - for idx in range(len(total_layer_norms_keys)): + for idx, layer_name in enumerate(total_layer_norms.keys()): layer_norm = total_layer_norms_values[idx] if torch.is_tensor(layer_norm): layer_norm = layer_norm.item() @@ -328,7 +328,7 @@ def compute_norm( if math.isnan(layer_norm): layer_norm = -2 - 
total_layer_norms[total_layer_norms_keys[idx]] = layer_norm + total_layer_norms[layer_name] = layer_norm return total_layer_norms diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 7af58dd..1451dc5 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -405,6 +405,7 @@ def record_current_batch_training_metrics( loss, moe_loss, grad_norm, + layer_grad_norm, metric, update_panel, ): @@ -526,6 +527,11 @@ def record_current_batch_training_metrics( else: writer.add_scalar(key=key, value=value, step=train_state.step_count) + # add layer grad norm + for key, value in layer_grad_norm.items(): + title = f"layer_grad_norm_group_{key}" + writer.add_scalars(key=title, value=value, step=train_state.step_count) + if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0: send_heartbeat("train_metrics", infos) diff --git a/train.py b/train.py index 139bac1..0df51d6 100644 --- a/train.py +++ b/train.py @@ -240,7 +240,7 @@ def main(args): trainer_result = trainer.step() assert trainer_result is not None - success_update, grad_norm_groups = trainer_result + success_update, grad_norm_groups, layer_grad_norm_groups = trainer_result if success_update: # update parameters successfully train_state.step_count += 1 else: @@ -268,6 +268,7 @@ def main(args): loss=loss, moe_loss=moe_loss, grad_norm=grad_norm_groups, + layer_grad_norm=layer_grad_norm_groups, metric=metric, update_panel=uniscale_logger is not None, )