From d680876a9af3f247d0f11d06466513b33f075ea1 Mon Sep 17 00:00:00 2001
From: JiaoPL
Date: Fri, 10 Nov 2023 16:39:59 +0800
Subject: [PATCH] monitor task only if DO_ALERT is True

---
Review note: the line structure of this patch has been restored so that it can
be parsed again, and the added SIGTERM message now interpolates the handler's
`sys_signal` argument instead of the `signal` module object. The blob id on the
`index` line below is the one recorded by the original commit.

 internlm/monitor/monitor.py | 120 ++++++++++++++++++------------------
 1 file changed, 60 insertions(+), 60 deletions(-)

diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py
index 25ad67f..35ebdd3 100644
--- a/internlm/monitor/monitor.py
+++ b/internlm/monitor/monitor.py
@@ -136,24 +136,26 @@ class MonitorManager(metaclass=SingletonMeta):
         self.last_step_loss = -1
         self.send_exception = try_import_send_exception()
         self.alert_file_path = None
+        self.enable_alert = False
+        self.light_monitor_address = None
 
     def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
         """Check loss value, if loss spike occurs, send alert message to Feishu."""
-        set_env_var(key="LOSS", value=cur_step_loss)
-        set_env_var(key="STEP_ID", value=step_count)
+        if self.enable_alert:
+            set_env_var(key="LOSS", value=cur_step_loss)
+            set_env_var(key="STEP_ID", value=step_count)
 
-        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
-            send_alert_message(
-                address=alert_address,
-                message=(
-                    f"Checking step by step: Loss spike may be happened in step {step_count}, "
-                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
-                ),
-            )
-        self.last_step_loss = cur_step_loss
+            if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
+                send_alert_message(
+                    address=alert_address,
+                    message=(
+                        f"Checking step by step: Loss spike may be happened in step {step_count}, "
+                        f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
+                    ),
+                )
+            self.last_step_loss = cur_step_loss
 
     def exception_should_be_alert(self, msg: str, alert_address: str = None):
-        enable_alert = gpc.config.monitor.alert.get("enable_feishu_alert", False)
         try:
             with open(self.alert_file_path, "a+") as f:
                 fcntl.flock(f, fcntl.LOCK_EX)
@@ -165,54 +167,49 @@ class MonitorManager(metaclass=SingletonMeta):
                 f.write(msg)
                 fcntl.flock(f, fcntl.LOCK_UN)
 
-            return enable_alert and True
+            return True
         except Exception as err:
             send_alert_message(
                 address=alert_address,
                 message=f"Failed to open ALERT file: {err}",
             )
-            return enable_alert and True
+            return True
 
     def monitor_exception(self, alert_address: str = None, excp_info: str = None):
         """Catch and format exception information, send alert message to Feishu."""
-        filtered_trace = excp_info.split("\n")[-10:]
-        format_trace = ""
-        for line in filtered_trace:
-            format_trace += "\n" + line
-        if (
-            self.send_exception
-            and gpc.config.monitor.alert.get("enable_feishu_alert", False)
-            and gpc.config.monitor.alert.get("light_monitor_address", None)
-        ):
-            self.send_exception(format_trace, gpc.get_global_rank())
-        message = f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}"
-        if self.alert_file_path:
-            if self.exception_should_be_alert(format_trace, alert_address):
-                send_feishu_msg_with_webhook(
-                    webhook=alert_address,
-                    title=get_job_key(),
-                    message=message,
-                )
-        else:
-            send_alert_message(alert_address, message)
+        if self.enable_alert:
+            filtered_trace = excp_info.split("\n")[-10:]
+            format_trace = ""
+            for line in filtered_trace:
+                format_trace += "\n" + line
+
+            if self.send_exception and self.light_monitor_address:
+                self.send_exception(format_trace, gpc.get_global_rank())
+            message = f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}"
+            if self.alert_file_path:
+                if self.exception_should_be_alert(format_trace, alert_address):
+                    send_feishu_msg_with_webhook(
+                        webhook=alert_address,
+                        title=get_job_key(),
+                        message=message,
+                    )
+            else:
+                send_alert_message(alert_address, message)
 
     def handle_sigterm(self, alert_address: str = None):
         """Catch SIGTERM signal, and send alert message to Feishu."""
 
         def sigterm_handler(sys_signal, frame):
-            print("receive frame: ", frame)
-            print("receive signal: ", sys_signal)
-            message = f"Process received signal {signal} and exited."
-            if (
-                self.send_exception
-                and gpc.config.monitor.alert.get("enable_feishu_alert", False)
-                and gpc.config.monitor.alert.get("light_monitor_address", None)
-            ):
-                self.send_exception(message, gpc.get_global_rank())
-            send_alert_message(
-                address=alert_address,
-                message=message,
-            )
+            if self.enable_alert:
+                print("receive frame: ", frame)
+                print("receive signal: ", sys_signal)
+                message = f"Process received signal {sys_signal} and exited."
+                if self.send_exception and self.light_monitor_address:
+                    self.send_exception(message, gpc.get_global_rank())
+                send_alert_message(
+                    address=alert_address,
+                    message=message,
+                )
 
         signal.signal(signal.SIGTERM, sigterm_handler)
 
@@ -236,21 +233,24 @@ class MonitorManager(metaclass=SingletonMeta):
 
         # initialize some variables for monitoring
         set_env_var(key="JOB_NAME", value=job_name)
+        self.enable_alert = gpc.config.monitor.alert.get("enable_feishu_alert", False)
 
-        # initialize alert file
-        self.alert_file_path = gpc.config.monitor.alert.get("alert_file_path")
-        if self.alert_file_path and gpc.is_rank_for_log():
-            alert_file_dir = os.path.dirname(self.alert_file_path)
-            os.makedirs(alert_file_dir, exist_ok=True)
-            if os.path.exists(self.alert_file_path):
-                os.remove(self.alert_file_path)
+        if self.enable_alert:
+            self.light_monitor_address = gpc.config.monitor.alert.get("light_monitor_address", None)
+            # initialize alert file
+            self.alert_file_path = gpc.config.monitor.alert.get("alert_file_path")
+            if self.alert_file_path and gpc.is_rank_for_log():
+                alert_file_dir = os.path.dirname(self.alert_file_path)
+                os.makedirs(alert_file_dir, exist_ok=True)
+                if os.path.exists(self.alert_file_path):
+                    os.remove(self.alert_file_path)
 
-        # start a monitor thread, periodically check the training status
-        self.monitor_thread = MonitorTracker(
-            alert_address=alert_address,
-            check_interval=monitor_interval_seconds,
-            loss_spike_limit=loss_spike_limit,
-        )
+            # start a monitor thread, periodically check the training status
+            self.monitor_thread = MonitorTracker(
+                alert_address=alert_address,
+                check_interval=monitor_interval_seconds,
+                loss_spike_limit=loss_spike_limit,
+            )
 
     def stop_monitor(self):
         """Stop the monitor and alert thread."""