From ff0fa7659f148bb45e3086e4e3b1abecdfb3048a Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Tue, 8 Aug 2023 11:18:15 +0800
Subject: [PATCH] feat(monitor): support monitor and alert (#175)

* feat(monitor): support monitor and alert

* feat(monitor.py): fix demo error

* feat(monitor.py): move cmd monitor args to config file

* feat(hybrid_zero_optim.py): if overflow occurs send alert msg

* feat(monitor.py): remove alert msg filter

* feat(monitor.py): optimize class MonitorTracker

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(train.py): update print to log

* style(ci): fix lint error

* fix(utils/evaluation.py): remove useless code

* fix(model/modeling_internlm.py): fix lint error

---------

Co-authored-by: huangting4201 <huangting3@sensetime.com>
---
 internlm/initialize/launch.py                 |   8 +-
 internlm/model/embedding.py                   |   4 +-
 internlm/model/linear.py                      |  20 +-
 internlm/model/modeling_internlm.py           |   3 +-
 internlm/model/utils.py                       |  17 +-
 internlm/monitor/__init__.py                  |   4 +
 internlm/monitor/alert.py                     |  53 ++++
 internlm/monitor/monitor.py                   | 226 ++++++++++++++++++
 internlm/monitor/utils.py                     |  32 +++
 .../solver/optimizer/hybrid_zero_optim.py     |   2 +
 internlm/utils/common.py                      |  12 -
 internlm/utils/evaluation.py                  |  17 +-
 train.py                                      |  50 +++-
 13 files changed, 399 insertions(+), 49 deletions(-)
 create mode 100644 internlm/monitor/__init__.py
 create mode 100644 internlm/monitor/alert.py
 create mode 100644 internlm/monitor/monitor.py
 create mode 100644 internlm/monitor/utils.py

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 1f60adc..dee6ffd 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -202,7 +202,13 @@ def args_sanity_check():
     if "sequence_parallel" not in gpc.config.model:
         gpc.config.model._add_item("sequence_parallel", False)
     else:
-        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
+        assert not (
+            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+        ), "sequence parallel does not support use_flash_attn=False"
+
+    # feishu webhook address for alerting
+    if "alert_address" not in gpc.config:
+        gpc.config._add_item("alert_address", None)
 
 
 def launch(
diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index 0951ccd..d35b9c1 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
-        
+
         if gpc.config.model.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
-        
+
         return output
 
 
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 2fa249c..50b4bf0 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )
 
 
@@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )
 
 
@@ -170,7 +178,13 @@ class FeedForward(nn.Module):
             dtype=dtype,
         )
         self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.model.sequence_parallel,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = RowParallelLinearTorch(
             hidden_features,
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 31138fa..4a7a4ee 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -31,6 +31,7 @@ MODEL_TYPE = "INTERNLM"
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()
 
+
 class PackedFlashBaseLayer1D(nn.Module):
     """
     1D Packed Flash Base Layer.
@@ -461,7 +462,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
-    sequence_parallel: bool = False,
+    sequence_parallel: bool = False,  # pylint: disable=W0613
 ):
     """
     Builde model with config
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index a84f058..8b80af2 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
 
 from internlm.core.context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
 
 
 def _split(input_, parallel_mode, dim=-1):
@@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
 
+
 def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     assert input.dtype == grad_output.dtype
     grad_weight = torch.matmul(grad_output.t(), input)
@@ -157,10 +161,11 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
 
+
 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
     Split the input and keep only the corresponding chuck to the rank.
-    
+
     Args:
         input_: input matrix.
         parallel_mode: parallel mode.
@@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _gather(grad_output, ctx.mode, ctx.dim), None, None
-    
+
 
 def split_forward_gather_backward(input_, parallel_mode, dim):
     return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
@@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
 def try_import_RMSNorm():
     """
     Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
-    
+
     """
     try:
         from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
+
         return RMSNorm
-    except ModuleNotFoundError as e:
-        from internlm.utils.logger import get_logger
-        logger = get_logger(__file__)
+    except ModuleNotFoundError:
         logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm
+
         return RMSNorm
diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py
new file mode 100644
index 0000000..b100cde
--- /dev/null
+++ b/internlm/monitor/__init__.py
@@ -0,0 +1,4 @@
+from .monitor import initialize_monitor_manager, send_alert_message
+from .utils import set_env_var
+
+__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
diff --git a/internlm/monitor/alert.py b/internlm/monitor/alert.py
new file mode 100644
index 0000000..78b6040
--- /dev/null
+++ b/internlm/monitor/alert.py
@@ -0,0 +1,53 @@
+import json
+import time
+
+import requests
+
+
+def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
+    """
+    Use Feishu robot to send messages with the given webhook.
+
+    Args:
+        webhook (str): The webhook to be used to send message.
+        title (str): The message title.
+        message (str): The message body.
+
+    Returns:
+        The JSON-decoded response of the request, or None if the request or the decoding raised.
+
+    Note:
+        Exceptions raised by the HTTP post request are caught and printed, they never propagate.
+
+    """
+
+    headers = {"Content-Type": "application/json;charset=utf-8"}
+    msg_body = {
+        "timestamp": int(time.time()),
+        "msg_type": "post",
+        "content": {
+            "post": {
+                "zh_cn": {
+                    "title": title,
+                    "content": [
+                        [
+                            {
+                                "tag": "text",
+                                "text": message,
+                            },
+                        ],
+                    ],
+                },
+            },
+        },
+    }
+
+    try:
+        res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
+        res = res.json()
+        print(f"Feishu webhook response: {res}")
+    except Exception as err:  # pylint: disable=W0703
+        print(f"HTTP Post error: {err}")
+        res = None  # report the failure to the caller through the return value instead of raising
+
+    return res
diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py
new file mode 100644
index 0000000..ca5cf55
--- /dev/null
+++ b/internlm/monitor/monitor.py
@@ -0,0 +1,226 @@
+import os
+import signal
+import socket
+import time
+from contextlib import contextmanager
+from threading import Thread
+
+from internlm.core.context import global_context as gpc
+from internlm.monitor.alert import send_feishu_msg_with_webhook
+from internlm.utils.common import SingletonMeta
+
+from .utils import get_job_key, set_env_var
+
+
+def send_alert_message(address: str = None, title: str = None, message: str = None):
+    """
+    Post an alert to a Feishu webhook; only the log rank sends, and only when an address is given.
+
+    Args:
+        address (str): The Feishu webhook address; nothing is sent when it is None (the default).
+        title (str): The message title; a falsy value (default None) is replaced by the job key.
+        message (str): The message body, defaults to None.
+    """
+
+    if address is None or not gpc.is_rank_for_log():
+        return
+
+    # fall back to "<job id>_<job name>" so that an untitled alert can still be traced to its job
+    alert_title = title or get_job_key()
+    send_feishu_msg_with_webhook(webhook=address, title=alert_title, message=message)
+
+
+class MonitorTracker(Thread):
+    """
+    Track job status and alert to Feishu during job training.
+
+    The thread starts itself on construction and runs as a daemon so that it can never keep the process alive.
+
+    Args:
+        alert_address (str): The Feishu webhook address for sending alerting messages.
+        check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
+        loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
+    """
+
+    def __init__(
+        self,
+        alert_address: str,
+        check_interval: float = 300,
+        loss_spike_limit: float = 1.5,
+    ):
+        super().__init__(daemon=True)
+        self.alert_address = alert_address
+        self.check_interval = check_interval
+        self.loss_spike_limit = loss_spike_limit
+        self.last_active_time = -1
+        self.last_loss_value = -1
+        self.stopped = False
+        self.start()
+
+    def run(self):
+        """Run the periodic checks until stop() is called."""
+
+        while not self.stopped:
+            try:
+                self._check_stuck()
+                self._check_loss_spike()
+            except Exception as err:  # pylint: disable=W0703
+                # a failing check must neither kill the thread nor skip the sleep below (busy loop)
+                print(f"Monitor check error: {err}")
+            time.sleep(self.check_interval)
+
+    def _check_stuck(self):
+        """
+        Alert when LAST_ACTIVE_TIMESTAMP has not advanced since the previous check.
+        """
+
+        raw_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
+        if raw_active_time is None:  # no activity has been recorded yet, nothing to compare
+            return
+        new_active_time = int(raw_active_time)
+        if new_active_time <= self.last_active_time:
+            self._send_alert("Training may be in stuck status, please check it.")
+        self.last_active_time = new_active_time
+
+    def _check_loss_spike(self):
+        """
+        Alert, on the log rank only, when LOSS exceeds loss_spike_limit times the previously seen value.
+        """
+
+        raw_loss_value = os.getenv("LOSS")
+        if raw_loss_value is None or not gpc.is_rank_for_log():
+            return
+        new_loss_value = float(raw_loss_value)
+        new_step_id = os.getenv("STEP_ID", "-1")
+
+        # multiply instead of dividing so that a previous loss of 0 cannot raise ZeroDivisionError
+        if self.last_loss_value != -1 and new_loss_value > self.loss_spike_limit * self.last_loss_value:
+            self._send_alert(
+                f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
+                f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
+            )
+
+        self.last_loss_value = new_loss_value
+
+    def _send_alert(self, message):
+        """
+        Send alerting message to the Feishu webhook address.
+
+        Args:
+            message (str): The alerting message to be sent.
+        """
+
+        send_alert_message(
+            address=self.alert_address,
+            message=message,
+        )
+
+    def stop(self):
+        """
+        Ask the tracker loop to exit; it takes effect once the current sleep is over.
+        """
+
+        self.stopped = True
+
+
+class MonitorManager(metaclass=SingletonMeta):
+    """
+    Monitor Manager for managing monitor thread and monitoring training status.
+    """
+
+    def __init__(self, loss_spike_limit: float = 1.5) -> None:
+        self.monitor_thread = None
+        self.loss_spike_limit = loss_spike_limit
+        self.last_step_loss = -1
+
+    def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
+        """Publish LOSS/STEP_ID for the tracker thread and alert if the loss spiked since the previous step."""
+        set_env_var(key="LOSS", value=cur_step_loss)
+        set_env_var(key="STEP_ID", value=step_count)
+
+        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
+            send_alert_message(
+                address=alert_address,
+                message=(
+                    f"Checking step by step: Loss spike may be happened in step {step_count}, "
+                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
+                ),
+            )
+        self.last_step_loss = cur_step_loss
+
+    def monitor_exception(self, alert_address: str = None, excp_info: str = None):
+        """Format the tail of the given traceback text and send it as an alert message to Feishu."""
+        filtered_trace = (excp_info or "").split("\n")[-10:]  # only the last 10 lines of the traceback are sent
+        format_trace = "".join("\n" + line for line in filtered_trace)
+        send_alert_message(
+            address=alert_address,
+            message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
+        )
+
+    def handle_sigterm(self, alert_address: str = None):
+        """Install a SIGTERM handler that sends an alert message to Feishu, then lets the process terminate."""
+
+        def sigterm_handler(sys_signal, frame):
+            print("receive frame: ", frame)
+            print("receive signal: ", sys_signal)
+            send_alert_message(
+                address=alert_address,
+                message=f"Process received signal {sys_signal} and exited.",
+            )
+            # a Python handler replaces the default action, so restore it and re-deliver the signal to really exit
+            signal.signal(sys_signal, signal.SIG_DFL)
+            os.kill(os.getpid(), sys_signal)
+
+        signal.signal(signal.SIGTERM, sigterm_handler)
+
+    def start_monitor(
+        self,
+        job_name: str,
+        alert_address: str,
+        monitor_interval_seconds: int = 300,
+        loss_spike_limit: float = 1.5,
+    ):
+        """
+        Initialize and start monitor thread for checking training job status, loss spike and so on.
+
+        Args:
+            job_name (str): The training job name.
+            alert_address (str): The Feishu webhook address for sending alert messages.
+            monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
+            loss_spike_limit (float): The ratio of current to previous loss that counts as a spike, defaults to 1.5.
+        """
+
+        # publish the job name through the environment, where get_job_name() reads it for alert titles
+        set_env_var(key="JOB_NAME", value=job_name)
+
+        # start a monitor thread, periodically check the training status
+        self.monitor_thread = MonitorTracker(
+            alert_address=alert_address,
+            check_interval=monitor_interval_seconds,
+            loss_spike_limit=loss_spike_limit,
+        )
+
+    def stop_monitor(self):
+        """Stop the monitor and alert thread."""
+        if self.monitor_thread is not None:
+            self.monitor_thread.stop()
+
+
+monitor_manager = MonitorManager()  # module-level instance; MonitorManager uses SingletonMeta as its metaclass
+
+
+@contextmanager
+def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
+    """Run the wrapped block with monitoring and start/finish alerts; a no-op when alert_address is None."""
+    if alert_address is not None:
+        try:
+            monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
+            monitor_manager.handle_sigterm(alert_address=alert_address)
+            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
+            yield
+        finally:
+            # alert the address this context was opened with, exactly like the "starting" message above
+            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
+            monitor_manager.stop_monitor()
+    else:
+        yield
diff --git a/internlm/monitor/utils.py b/internlm/monitor/utils.py
new file mode 100644
index 0000000..f64c7dc
--- /dev/null
+++ b/internlm/monitor/utils.py
@@ -0,0 +1,32 @@
+import os
+from datetime import datetime
+
+
+def now_time():  # current local time as "<abbreviated month><day>_<hour>-<minute>-<second>"
+    return f"{datetime.now():%b%d_%H-%M-%S}"
+
+
+def set_env_var(key, value):  # store value in this process's environment; key and value are both coerced to str
+    os.environ[str(key)] = str(value)
+
+
+def get_job_id():
+    """Return SLURM_JOB_ID if it is set, else K8S_WORKSPACE_ID if it is set, else the string "none"."""
+    for env_name in ("SLURM_JOB_ID", "K8S_WORKSPACE_ID"):
+        candidate = os.getenv(env_name)
+        if candidate is not None:
+            return candidate
+
+    return "none"
+
+
+def get_job_name():
+    """Return the JOB_NAME environment variable, or "unknown-<now_time()>" when it is not set."""
+    configured_name = os.getenv("JOB_NAME")
+    if configured_name is None:
+        return f"unknown-{now_time()}"
+    return configured_name
+
+
+def get_job_key():  # "<job id>_<job name>", used as the default alert title
+    return "_".join((get_job_id(), get_job_name()))
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 116ffc2..618b772 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -28,6 +28,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.monitor import send_alert_message
 
 from .utils import compute_norm
 
@@ -542,6 +543,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
+                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 584078f..d479284 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -34,18 +34,6 @@ def get_master_node():
     return result
 
 
-def get_process_rank():
-    proc_rank = -1
-    if os.getenv("SLURM_PROCID") is not None:
-        proc_rank = int(os.getenv("SLURM_PROCID"))
-    elif os.getenv("RANK") is not None:
-        # In k8s env, we use $RANK.
-        proc_rank = int(os.getenv("RANK"))
-
-    # assert proc_rank != -1, "get_process_rank cant't get right process rank!"
-    return proc_rank
-
-
 def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
     if torch.is_tensor(norm) and norm.device.type != "cuda":
         norm = norm.to(torch.cuda.current_device())
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 8424e16..d10f0c1 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -6,8 +6,8 @@ from tqdm import tqdm
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.metrics import AccPerplex
 from internlm.core.scheduler import SchedulerMetricHook
+from internlm.model.metrics import AccPerplex
 
 
 @contextmanager
@@ -90,15 +90,9 @@ def evaluate_on_val_dls(
                     total_val_bsz = len(batch[1])
                     assert total_val_bsz % data_cfg.micro_bsz == 0
                     num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    if gpc.config.model.sequence_parallel:
-                        sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
-                        tensor_shape = torch.Size(
-                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
-                        )
-                    else:
-                        tensor_shape = torch.Size(
-                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                        )
+                    tensor_shape = torch.Size(
+                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                    )
 
                     with switch_evaluation_pipeline_scheduler(
                         trainer=trainer,
@@ -114,7 +108,6 @@ def evaluate_on_val_dls(
                     assert total_val_bsz % data_cfg.micro_bsz == 0
                     grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                     grad_accum_batch_size = data_cfg.micro_bsz
-                    # import pdb; pdb.set_trace()
                     with switch_evaluation_no_pipeline_scheduler(
                         trainer=trainer,
                         grad_accum_size=grad_accum_size,
@@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
         gpc.config.model.sequence_parallel = False
         yield
     finally:
-        gpc.config.model.sequence_parallel = prev_mode
\ No newline at end of file
+        gpc.config.model.sequence_parallel = prev_mode
diff --git a/train.py b/train.py
index 59729e7..fa8d130 100644
--- a/train.py
+++ b/train.py
@@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
+from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
@@ -37,7 +39,6 @@ from internlm.utils.common import (
     BatchSkipper,
     get_master_node,
     get_megatron_flops,
-    get_process_rank,
     launch_time,
     parse_args,
 )
@@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
 
 
 def initialize_llm_logger(start_time: str):
+    """
+    Initialize customized uniscale logger.
+
+    Args:
+        start_time (str): The launch time of current training job.
+
+    Returns: The instance of uniscale logger.
+    """
+
     uniscale_logger = initialize_uniscale_logger(
         job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
     )
@@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
 
 
 def get_validation_data_loader(num_worker: int = 0):
+    """Generate and return the validation data loader."""
+
     data_cfg = gpc.config.data
 
     if not data_cfg.valid_folder:
@@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
     Print some training metrics of current batch.
     """
 
+    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
+
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
     if is_no_pp_or_last_stage():
@@ -405,12 +419,11 @@ def record_current_batch_training_metrics(
         else:
             logger.info(line)
 
+        # if loss spike occurs, send alert info to feishu
+        mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
+
 
 def main(args):
-    # initialize distributed environment
-    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
-    assert hasattr(gpc, "config") and gpc.config is not None
-
     # init setting
     skip_batches = gpc.config.data.skip_batches
     total_steps = gpc.config.data.total_steps
@@ -477,8 +490,8 @@ def main(args):
         model_load_path = load_model_only_folder
     else:
         logger.info(
-            f"===========New Run {current_time} on host:{socket.gethostname()},"
-            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
+            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
+            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
             f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
         )
 
@@ -594,6 +607,9 @@ def main(args):
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
             if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                send_alert_message(
+                    address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
+                )
 
         # calculate and record the training metrics, eg. loss, accuracy and so on.
         record_current_batch_training_metrics(
@@ -646,9 +662,19 @@ def main(args):
 
 if __name__ == "__main__":
     args = parse_args()
+    hostname = socket.gethostname()
 
-    try:
-        main(args)
-    except Exception:
-        print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
-        traceback.print_exc()
+    # initialize distributed environment
+    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+
+    # initialize monitor manager context
+    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+        try:
+            main(args)
+        except Exception:
+            logger.error(
+                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
+                exc_info=traceback.format_exc(),
+            )
+            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())