feat(monitor): support monitor and alert (#175)

* feat(monitor): support monitor and alert

* feat(monitor.py): fix demo error

* feat(monitor.py): move cmd monitor args to config file

* feat(hybrid_zero_optim.py): if overflow occurs send alert msg

* feat(monitor.py): remove alert msg filter

* feat(monitor.py): optimize class MonitorTracker

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(train.py): update print to log

* style(ci): fix lint error

* fix(utils/evaluation.py): remove useless code

* fix(model/modeling_internlm.py): fix lint error

---------

Co-authored-by: huangting4201 <huangting3@sensetime.com>
huangting4201 2023-08-08 11:18:15 +08:00 committed by GitHub
parent c219065348
commit ff0fa7659f
13 changed files with 399 additions and 49 deletions


@@ -202,7 +202,13 @@ def args_sanity_check():
if "sequence_parallel" not in gpc.config.model:
gpc.config.model._add_item("sequence_parallel", False)
else:
assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
assert not (
gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
def launch(
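For reference, a minimal sketch of what the new alert option could look like in a training config file; the webhook URL below is only a placeholder, not a real endpoint. If alert_address is omitted, args_sanity_check() adds it with a default of None and alerting is simply disabled.

# Hypothetical snippet from a training config (the module loaded into gpc.config).
alert_address = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxx"  # placeholder Feishu webhook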


@@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
if gpc.config.model.sequence_parallel:
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
return output


@@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
else:
weight = self.weight
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
)
@@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
else:
weight = self.weight
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
)
@@ -170,7 +178,13 @@ class FeedForward(nn.Module):
dtype=dtype,
)
self.w2 = ColumnParallelLinearTorch(
in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.model.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = RowParallelLinearTorch(
hidden_features,


@@ -31,6 +31,7 @@ MODEL_TYPE = "INTERNLM"
logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()
class PackedFlashBaseLayer1D(nn.Module):
"""
1D Packed Flash Base Layer.
@@ -461,7 +462,7 @@ def build_model_with_cfg(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
sequence_parallel: bool = False,
sequence_parallel: bool = False, # pylint: disable=W0613
):
"""
Build model with config


@@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
def _split(input_, parallel_mode, dim=-1):
@@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
assert input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), input)
@@ -157,10 +161,11 @@ def fused_dense_func_torch(
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the chunk corresponding to the rank.
Args:
input_: input matrix.
parallel_mode: parallel mode.
@@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
@@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
def try_import_RMSNorm():
"""
Try to import MixedFusedRMSNorm from apex; if that fails, fall back to our RMSNorm implementation.
"""
try:
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
return RMSNorm
except ModuleNotFoundError as e:
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
except ModuleNotFoundError:
logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
from internlm.model.norm import RMSNormTorch as RMSNorm
return RMSNorm


@@ -0,0 +1,4 @@
from .monitor import initialize_monitor_manager, send_alert_message
from .utils import set_env_var
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]

internlm/monitor/alert.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import json
import time
import requests
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
"""
Use Feishu robot to send messages with the given webhook.
Args:
webhook (str): The webhook to be used to send message.
title (str): The message title.
message (str): The message body.
Returns:
The response from the request, or None if the request raised an exception.
Raises:
Exception: An exception raised by the HTTP POST request.
"""
headers = {"Content-Type": "application/json;charset=utf-8"}
msg_body = {
"timestamp": int(time.time()),
"msg_type": "post",
"content": {
"post": {
"zh_cn": {
"title": title,
"content": [
[
{
"tag": "text",
"text": message,
},
],
],
},
},
},
}
try:
res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
res = res.json()
print(f"Feishu webhook response: {res}")
except Exception as err: # pylint: disable=W0703
print(f"HTTP Post error: {err}")
res = None
return res
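A brief usage sketch for the helper above; the webhook URL, title and message are placeholders. The function returns the parsed JSON response from Feishu, or None if the POST raised an exception (the error is printed inside the helper).

resp = send_feishu_msg_with_webhook(
    webhook="https://open.feishu.cn/open-apis/bot/v2/hook/xxxx",  # placeholder webhook
    title="12345_my-train-job",  # send_alert_message() fills this with get_job_key() when no title is given
    message="Overflow occurs, please check it.",
)
# resp is the decoded JSON reply, or None on failure.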

internlm/monitor/monitor.py (new file, 226 lines)

@@ -0,0 +1,226 @@
import os
import signal
import socket
import time
from contextlib import contextmanager
from threading import Thread
from internlm.core.context import global_context as gpc
from internlm.monitor.alert import send_feishu_msg_with_webhook
from internlm.utils.common import SingletonMeta
from .utils import get_job_key, set_env_var
def send_alert_message(address: str = None, title: str = None, message: str = None):
"""
Send alert messages to the given Feishu webhook address from the log rank.
Args:
address (str): The alert address to be used to send message, defaults to None.
title (str): The message title, defaults to None.
message (str): The message body, defaults to None.
"""
if address is not None and gpc.is_rank_for_log():
send_feishu_msg_with_webhook(
webhook=address,
title=title if title else get_job_key(),
message=message,
)
class MonitorTracker(Thread):
"""
Track job status and alert to Feishu during job training.
Args:
alert_address (str): The Feishu webhook address for sending alerting messages.
check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
"""
def __init__(
self,
alert_address: str,
check_interval: float = 300,
loss_spike_limit: float = 1.5,
):
super().__init__()
self.alert_address = alert_address
self.check_interval = check_interval
self.loss_spike_limit = loss_spike_limit
self.last_active_time = -1
self.last_loss_value = -1
self.stopped = False
self.start()
def run(self):
"""
start the monitor tracker.
"""
while not self.stopped:
try:
self._check_stuck()
self._check_loss_spike()
except Exception:
continue
time.sleep(self.check_interval)
def _check_stuck(self):
"""
Check training status for potential stuck condition.
"""
new_active_time = -1
if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
self._send_alert("Training may be in stuck status, please check it.")
self.last_active_time = new_active_time
def _check_loss_spike(self):
"""
Check for loss value spikes.
"""
if gpc.is_rank_for_log():
new_loss_value = -1
new_step_id = -1
if os.getenv("LOSS") is not None:
new_loss_value = os.getenv("LOSS")
if os.getenv("STEP_ID") is not None:
new_step_id = os.getenv("STEP_ID")
if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
assert int(new_step_id) >= 0
self._send_alert(
f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
)
self.last_loss_value = new_loss_value
def _send_alert(self, message):
"""
Send alerting message to the Feishu webhook address.
Args:
message (str): The alerting message to be sent.
"""
send_alert_message(
address=self.alert_address,
message=message,
)
def stop(self):
"""
Stop the monitor tracker.
"""
self.stopped = True
class MonitorManager(metaclass=SingletonMeta):
"""
Monitor Manager for managing monitor thread and monitoring training status.
"""
def __init__(self, loss_spike_limit: float = 1.5) -> None:
self.monitor_thread = None
self.loss_spike_limit = loss_spike_limit
self.last_step_loss = -1
def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
"""Check loss value, if loss spike occurs, send alert message to Feishu."""
set_env_var(key="LOSS", value=cur_step_loss)
set_env_var(key="STEP_ID", value=step_count)
if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
send_alert_message(
address=alert_address,
message=(
f"Checking step by step: Loss spike may be happened in step {step_count}, "
f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
),
)
self.last_step_loss = cur_step_loss
def monitor_exception(self, alert_address: str = None, excp_info: str = None):
"""Catch and format exception information, send alert message to Feishu."""
filtered_trace = excp_info.split("\n")[-10:]
format_trace = ""
for line in filtered_trace:
format_trace += "\n" + line
send_alert_message(
address=alert_address,
message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
)
def handle_sigterm(self, alert_address: str = None):
"""Catch SIGTERM signal, and send alert message to Feishu."""
def sigterm_handler(sys_signal, frame):
print("receive frame: ", frame)
print("receive signal: ", sys_signal)
send_alert_message(
address=alert_address,
message=f"Process received signal {signal} and exited.",
)
signal.signal(signal.SIGTERM, sigterm_handler)
def start_monitor(
self,
job_name: str,
alert_address: str,
monitor_interval_seconds: int = 300,
loss_spike_limit: float = 1.5,
):
"""
Initialize and start monitor thread for checking training job status, loss spike and so on.
Args:
job_name (str): The training job name.
alert_address (str): The Feishu webhook address for sending alert messages.
monitor_interval_seconds (int): The monitoring interval in seconds, defaults to 300.
loss_spike_limit (float): The ratio of the current loss to the previous loss above which a loss spike
is reported, defaults to 1.5.
"""
# initialize some variables for monitoring
set_env_var(key="JOB_NAME", value=job_name)
# start a monitor thread, periodically check the training status
self.monitor_thread = MonitorTracker(
alert_address=alert_address,
check_interval=monitor_interval_seconds,
loss_spike_limit=loss_spike_limit,
)
def stop_monitor(self):
"""Stop the monitor and alert thread."""
if self.monitor_thread is not None:
self.monitor_thread.stop()
monitor_manager = MonitorManager()
@contextmanager
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
if alert_address is not None:
try:
monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
monitor_manager.handle_sigterm(alert_address=alert_address)
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield
finally:
send_alert_message(
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
monitor_manager.stop_monitor()
else:
yield
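The intended usage matches the train.py change later in this commit: the training entry point is wrapped in the context manager so the start/finish alerts, the SIGTERM handler and exception reporting are all set up in one place, and a None alert_address turns the whole manager into a no-op. A condensed sketch (gpc, main, args and traceback come from train.py):

from internlm.monitor import initialize_monitor_manager
from internlm.monitor.monitor import monitor_manager as mm

with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
    try:
        main(args)
    except Exception:
        # forward the formatted traceback to the Feishu webhook
        mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())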

internlm/monitor/utils.py (new file, 32 lines)

@@ -0,0 +1,32 @@
import os
from datetime import datetime
def now_time():
return datetime.now().strftime("%b%d_%H-%M-%S")
def set_env_var(key, value):
os.environ[str(key)] = str(value)
def get_job_id():
job_id = "none"
if os.getenv("SLURM_JOB_ID") is not None:
job_id = os.getenv("SLURM_JOB_ID")
elif os.getenv("K8S_WORKSPACE_ID") is not None:
job_id = os.getenv("K8S_WORKSPACE_ID")
return job_id
def get_job_name():
job_name = f"unknown-{now_time()}"
if os.getenv("JOB_NAME") is not None:
job_name = os.getenv("JOB_NAME")
return job_name
def get_job_key():
return f"{get_job_id()}_{get_job_name()}"
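A small sketch of how these helpers compose; the SLURM job id and job name below are hypothetical values.

# Assuming the scheduler exported SLURM_JOB_ID=12345 and the job name was set via set_env_var:
set_env_var("JOB_NAME", "7b_train")
print(get_job_key())  # -> "12345_7b_train"
# Without SLURM/K8s variables the key falls back to "none_unknown-<timestamp>".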


@@ -28,6 +28,7 @@ from internlm.solver.optimizer.utils import (
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.monitor import send_alert_message
from .utils import compute_norm
@@ -542,6 +543,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None
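Note that the optimizer can call send_alert_message unconditionally: when alert_address keeps its default of None (added by args_sanity_check above), the guard inside send_alert_message returns immediately, so existing configs see no behaviour change. A minimal sketch of that no-op path:

from internlm.monitor import send_alert_message

# No webhook configured: the call returns without posting anything.
send_alert_message(address=None, message="Overflow occurs, please check it.")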


@@ -34,18 +34,6 @@ def get_master_node():
return result
def get_process_rank():
proc_rank = -1
if os.getenv("SLURM_PROCID") is not None:
proc_rank = int(os.getenv("SLURM_PROCID"))
elif os.getenv("RANK") is not None:
# In k8s env, we use $RANK.
proc_rank = int(os.getenv("RANK"))
# assert proc_rank != -1, "get_process_rank cant't get right process rank!"
return proc_rank
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if torch.is_tensor(norm) and norm.device.type != "cuda":
norm = norm.to(torch.cuda.current_device())


@@ -6,8 +6,8 @@ from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.metrics import AccPerplex
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex
@contextmanager
@@ -90,15 +90,9 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.model.sequence_parallel:
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
)
else:
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
@@ -114,7 +108,6 @@ def evaluate_on_val_dls(
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
# import pdb; pdb.set_trace()
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
@@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
gpc.config.model.sequence_parallel = False
yield
finally:
gpc.config.model.sequence_parallel = prev_mode
gpc.config.model.sequence_parallel = prev_mode


@@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
@@ -37,7 +39,6 @@ from internlm.utils.common import (
BatchSkipper,
get_master_node,
get_megatron_flops,
get_process_rank,
launch_time,
parse_args,
)
@@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
def initialize_llm_logger(start_time: str):
"""
Initialize customized uniscale logger.
Args:
start_time (str): The launch time of current training job.
Returns: The instance of uniscale logger.
"""
uniscale_logger = initialize_uniscale_logger(
job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
)
@@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
def get_validation_data_loader(num_worker: int = 0):
"""Generate and return the validation data loader."""
data_cfg = gpc.config.data
if not data_cfg.valid_folder:
@@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
Print some training metrics of current batch.
"""
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
@@ -405,12 +419,11 @@ def record_current_batch_training_metrics(
else:
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
def main(args):
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# init setting
skip_batches = gpc.config.data.skip_batches
total_steps = gpc.config.data.total_steps
@@ -477,8 +490,8 @@ def main(args):
model_load_path = load_model_only_folder
else:
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},"
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
@@ -594,6 +607,9 @@ def main(args):
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
)
# calculate and record the training metrics, eg. loss, accuracy and so on.
record_current_batch_training_metrics(
@@ -646,9 +662,19 @@ def main(args):
if __name__ == "__main__":
args = parse_args()
hostname = socket.gethostname()
try:
main(args)
except Exception:
print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
traceback.print_exc()
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
try:
main(args)
except Exception:
logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
exc_info=traceback.format_exc(),
)
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())