From b9202b12bcc23959c2c3b9539c363a8f323e23e8 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 1 Sep 2023 13:24:46 +0800 Subject: [PATCH] feat(utils/writer.py): support writer add_scalars for writing dict data (#257) * feat(utils/writer.py): support writer add_scalars interface for writing dict data * feat(hybrid_zero_optim.py): change grad_norm_groups list to dict --- .../solver/optimizer/hybrid_zero_optim.py | 35 ++++++++++-------- internlm/train/training_internlm.py | 36 +++++++++++-------- internlm/utils/writer.py | 8 +++++ train.py | 5 ++- 4 files changed, 52 insertions(+), 32 deletions(-) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 8bdeccf..63d2bfa 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -497,6 +497,7 @@ class HybridZeroOptimizer(BaseOptimizer): grads = [self.padding_grad] params = [self.padding_tensor] + norm = 0 if self._clip_grad_norm > 0: # this norm is before scaling, it will be very large norm = compute_norm( @@ -542,15 +543,15 @@ class HybridZeroOptimizer(BaseOptimizer): self._param_store.clear_grads_of_previous_reduced_params() # compute norm for gradients in the last bucket - total_norms = [] + total_norms = {} for group_id in range(self.num_param_groups): - total_norms.append( - self._compute_norm_with_stage( - group_id=group_id, - last_bucket=True, - last_stage=True, - previous_norm=groups_norms[group_id], - ) + group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default" + group_name = f"{group_id}_{group_name}" + total_norms[group_name] = self._compute_norm_with_stage( + group_id=group_id, + last_bucket=True, + last_stage=True, + previous_norm=groups_norms[group_id], ) timer("sync_grad").start() @@ -569,7 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer): # found_inf = self._check_overflow() # Because you may encounter inf when computing norm - if -1 in norms: + if -1 in norms.values(): found_inf = True loss_scale = float(self.loss_scale.item()) # backup @@ -617,15 +618,17 @@ class HybridZeroOptimizer(BaseOptimizer): # unscale and clip grads # get the global norm - global_norm_groups = [] + global_norm_groups = {} if self._clip_grad_norm > 0: - for norm in norms: - global_norm_groups.append(norm**0.5) + for group_name, norm in norms.items(): + global_norm_groups[group_name] = norm**0.5 # the following operations are performed only on the rank to which parameters are assigned. 
if gpc.config.model.dtype is not torch.float32: - if len(single_grad_partition_groups) != 0: - self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale) + if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0: + self._unscale_and_clip_grads( + single_grad_partition_groups, list(global_norm_groups.values()), loss_scale + ) # update the parameters timer("step").start() @@ -652,7 +655,9 @@ class HybridZeroOptimizer(BaseOptimizer): # update gradients may not be needed here, because the sync_params function is used in initialization, # so synchronization is maintained - return True, [global_norm / loss_scale for global_norm in global_norm_groups] + for group_name, global_norm in global_norm_groups.items(): + global_norm_groups[group_name] = global_norm / loss_scale + return True, global_norm_groups def broadcast_params(self): handles = [] diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index bab56f1..9c2ded0 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -389,23 +389,31 @@ def record_current_batch_training_metrics( line = "" for key, value in infos.items(): line += f"{key}={value} " - writer.add_scalar(key=key, value=value, step=train_state.step_count) + if isinstance(value, dict): + writer.add_scalars(key=key, value=value, step=train_state.step_count) + else: + writer.add_scalar(key=key, value=value, step=train_state.step_count) if update_panel: + # metrics shown with dashboard panels + panel_metrics = { + "step": batch_count, + "lr": lr, + "num_consumed_tokens": train_state.num_consumed_tokens, + "loss": loss.item(), + "flops": tflops, + "tgs": tk_per_gpu, + "acc": acc_perplex["acc"], + "perplexity": acc_perplex["perplexity"], + "fwd_bwd_time": fwd_bwd_time, + } + for norm_key, norm_value in grad_norm.items(): + panel_metrics[norm_key] = norm_value + logger.info( - line, - extra={ - "step": batch_count, - "lr": lr, - "num_consumed_tokens": train_state.num_consumed_tokens, - "grad_norm": grad_norm, - "loss": loss.item(), - "flops": tflops, - "tgs": tk_per_gpu, - "acc": acc_perplex["acc"], - "perplexity": acc_perplex["perplexity"], - "fwd_bwd_time": fwd_bwd_time, - }, + "{line}", + line=line, + extra=panel_metrics, ) else: logger.info(line) diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py index 0997817..b519b95 100644 --- a/internlm/utils/writer.py +++ b/internlm/utils/writer.py @@ -134,6 +134,14 @@ class Writer: except Exception: traceback.print_exc() + def add_scalars(self, key, value, step): + try: + assert isinstance(value, dict) + if self.enable_tb and self.tb_writer is not None: + self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step) + except Exception: + traceback.print_exc() + def add_text(self, key, value, step): try: if self.enable_tb and self.tb_writer is not None: diff --git a/train.py b/train.py index de7cc7c..902f8c0 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,6 @@ import time import traceback from functools import partial -import numpy as np import torch import torch.distributed as dist @@ -236,7 +235,7 @@ def main(args): train_state.step_count += 1 else: train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. 
- if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case + if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case logger.warning(f"Warning: skip parameter update at step {batch_count}.") send_alert_message( address=gpc.config.alert_address, @@ -257,7 +256,7 @@ def main(args): trainer=trainer, start_time=start_time, loss=loss, - grad_norm=np.array(grad_norm_groups), + grad_norm=grad_norm_groups, metric=metric, update_panel=uniscale_logger is not None, )