From 7c99e01ca71dc92dbc1e5b6ed42ad9ad8df97b13 Mon Sep 17 00:00:00 2001 From: jiaopenglong <44927264+JiaoPL@users.noreply.github.com> Date: Thu, 7 Sep 2023 21:49:05 +0800 Subject: [PATCH] fix(monitor): add alert switch and refactor monitor config (#285) * add monitor switch * add switch to light monitor * fix alert_address is empty * fix light monitor heartbeat * init light_monitor on rank_log only * add comments to the monitoring config * optimize config --- configs/7B_sft.py | 10 ++++++ internlm/initialize/launch.py | 36 +++++++++++++------ internlm/monitor/monitor.py | 4 +-- .../solver/optimizer/hybrid_zero_optim.py | 2 +- internlm/train/training_internlm.py | 8 +++-- train.py | 12 ++++--- 6 files changed, 51 insertions(+), 21 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 0ccc5e0..027d216 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -1,4 +1,5 @@ JOB_NAME = "7b_train" +DO_ALERT = False SEQ_LEN = 2048 HIDDEN_SIZE = 4096 @@ -150,3 +151,12 @@ parallel = dict( cudnn_deterministic = False cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + ), +) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 0945337..5ad51fa 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -254,9 +254,22 @@ def args_sanity_check(): gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False ), "sequence parallel does not support use_flash_attn=False" - # feishu webhook address for alerting - if "alert_address" not in gpc.config: - gpc.config._add_item("alert_address", None) + # monitoring default config + monitor_default_config = { + "alert_address": None, # compatible with old alert config + "monitor": { # new monitoring config + "alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None} + }, + } + + for key, value in monitor_default_config.items(): + if key not in gpc.config: + gpc.config._add_item(key, value) + + alert = gpc.config.monitor.alert + + if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log(): + logger.warning("alert is enable but alert_address is not set") optim_ckpt = gpc.config.hybrid_zero_optimizer if "zero_overlap_communication" in optim_ckpt: @@ -334,14 +347,6 @@ def launch( f"tensor parallel size: {gpc.tensor_parallel_size}", ) - # init light monitor client - light_monitor_address = gpc.config.get("light_monitor_address", None) - if light_monitor_address is None: - if gpc.is_rank_for_log(): - logger.warning("monitor address is none, monitor could not be used!") - else: - initialize_light_monitor(light_monitor_address) - def launch_from_slurm( config: Union[str, Path, Config, Dict], @@ -446,6 +451,15 @@ def initialize_distributed_env( if args_check: args_sanity_check() + # init light monitor client + alert_config = gpc.config.monitor.alert + if alert_config.enable_feishu_alert and gpc.is_rank_for_log(): + light_monitor_address = alert_config.light_monitor_address + if light_monitor_address: + initialize_light_monitor(light_monitor_address) + else: + logger.warning("monitor address is none, monitor could not be used!") + def get_config_value(config, key, defalut): try: diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index ca5cf55..7d22691 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -218,9 +218,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None): send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.") yield finally: - send_alert_message( - address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed." - ) + send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.") monitor_manager.stop_monitor() else: yield diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 0e44c99..0c120f4 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -580,7 +580,7 @@ class HybridZeroOptimizer(BaseOptimizer): if gpc.is_rank_for_log(): logger.warning("Overflow occurs, please check it.") send_alert_message( - address=gpc.config.alert_address, + address=gpc.config.monitor.alert.feishu_alert_address, message="Overflow occurs, please check it.", ) self._grad_store._averaged_gradients = dict() diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 3a2e0bd..402d1ed 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -406,7 +406,7 @@ def record_current_batch_training_metrics( else: writer.add_scalar(key=key, value=value, step=train_state.step_count) - if gpc.config.get("light_monitor_address", None) and batch_count % 50 == 0: + if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0: send_heartbeat("train_metrics", infos) if update_panel: @@ -434,4 +434,8 @@ def record_current_batch_training_metrics( logger.info(line) # if loss spike occurs, send alert info to feishu - mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item()) + mm.monitor_loss_spike( + alert_address=gpc.config.monitor.alert.feishu_alert_address, + step_count=batch_count, + cur_step_loss=loss.item(), + ) diff --git a/train.py b/train.py index dbdc09d..b9fe6af 100644 --- a/train.py +++ b/train.py @@ -120,7 +120,7 @@ def main(args): train_dl=train_dl, model_config=gpc.config.model, model_config_file="".join(config_lines), - feishu_address=gpc.config.alert_address, + feishu_address=gpc.config.monitor.alert.feishu_alert_address, ) # Loading other persistent training states. @@ -237,7 +237,7 @@ def main(args): if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case logger.warning(f"Warning: skip parameter update at step {batch_count}.") send_alert_message( - address=gpc.config.alert_address, + address=gpc.config.monitor.alert.feishu_alert_address, message=f"Warning: skip parameter update at step {batch_count}.", ) @@ -297,11 +297,15 @@ if __name__ == "__main__": assert hasattr(gpc, "config") and gpc.config is not None # initialize monitor manager context - with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address): + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): try: main(args) except Exception: logger.error( f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", ) - mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc()) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + )