feat(monitor): support monitor and alert (#175)

* feat(monitor): support monitor and alert

* feat(monitor.py): fix demo error

* feat(monitor.py): move cmd monitor args to config file

* feat(hybrid_zero_optim.py): if overflow occurs send alert msg

* feat(monitor.py): remove alert msg filter

* feat(monitor.py): optimize class MonitorTracker

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(train.py): update print to log

* style(ci): fix lint error

* fix(utils/evaluation.py): remove useless code

* fix(model/modeling_internlm.py): fix lint error

---------

Co-authored-by: huangting4201 <huangting3@sensetime.com>
huangting4201 2023-08-08 11:18:15 +08:00 committed by GitHub
parent c219065348
commit ff0fa7659f
13 changed files with 399 additions and 49 deletions


@@ -202,7 +202,13 @@ def args_sanity_check():
     if "sequence_parallel" not in gpc.config.model:
         gpc.config.model._add_item("sequence_parallel", False)
     else:
-        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
+        assert not (
+            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+        ), "sequence parallel does not support use_flash_attn=False"
+
+    # feishu webhook address for alerting
+    if "alert_address" not in gpc.config:
+        gpc.config._add_item("alert_address", None)


 def launch(
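Note that alerting is opt-in: the sanity check only fills in a None default. A minimal, hedged sketch of how the option might be set in a Python training config (the job name and webhook token below are placeholders, not values from this commit):

JOB_NAME = "7b_train"  # placeholder job name, used by the monitor to tag alert messages
alert_address = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxx-your-token"  # placeholder Feishu bot webhook; keep None to disable alerting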


@@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

         if gpc.config.model.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

         return output


@@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )
@@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )
@@ -170,7 +178,13 @@ class FeedForward(nn.Module):
             dtype=dtype,
         )
         self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.model.sequence_parallel,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = RowParallelLinearTorch(
             hidden_features,


@@ -31,6 +31,7 @@ MODEL_TYPE = "INTERNLM"
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()

+
 class PackedFlashBaseLayer1D(nn.Module):
     """
     1D Packed Flash Base Layer.
@@ -461,7 +462,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
-    sequence_parallel: bool = False,
+    sequence_parallel: bool = False,  # pylint: disable=W0613
 ):
     """
     Builde model with config


@@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup

 from internlm.core.context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)


 def _split(input_, parallel_mode, dim=-1):
@@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)

+
 def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     assert input.dtype == grad_output.dtype
     grad_weight = torch.matmul(grad_output.t(), input)
@@ -157,10 +161,11 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)

+
 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
     Split the input and keep only the corresponding chuck to the rank.

     Args:
         input_: input matrix.
         parallel_mode: parallel mode.
@@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _gather(grad_output, ctx.mode, ctx.dim), None, None

+
 def split_forward_gather_backward(input_, parallel_mode, dim):
     return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
@@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
 def try_import_RMSNorm():
     """
     Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
     """
     try:
         from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm

         return RMSNorm
-    except ModuleNotFoundError as e:
-        from internlm.utils.logger import get_logger
-
-        logger = get_logger(__file__)
+    except ModuleNotFoundError:
         logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm

         return RMSNorm

internlm/monitor/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
from .monitor import initialize_monitor_manager, send_alert_message
from .utils import set_env_var

__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]

internlm/monitor/alert.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import json
import time

import requests


def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
    """
    Use Feishu robot to send messages with the given webhook.

    Args:
        webhook (str): The webhook to be used to send message.
        title (str): The message title.
        message (str): The message body.

    Returns:
        The response from the request. Or catch the exception and return None.

    Raises:
        Exception: An exception raised by the HTTP post request.
    """

    headers = {"Content-Type": "application/json;charset=utf-8"}
    msg_body = {
        "timestamp": int(time.time()),
        "msg_type": "post",
        "content": {
            "post": {
                "zh_cn": {
                    "title": title,
                    "content": [
                        [
                            {
                                "tag": "text",
                                "text": message,
                            },
                        ],
                    ],
                },
            },
        },
    }

    try:
        res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
        res = res.json()
        print(f"Feishu webhook response: {res}")
    except Exception as err:  # pylint: disable=W0703
        print(f"HTTP Post error: {err}")
        res = None

    return res
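For reference, a hedged usage sketch of the helper above; the webhook URL is a placeholder, and the function simply returns None if the HTTP request fails:

from internlm.monitor.alert import send_feishu_msg_with_webhook

resp = send_feishu_msg_with_webhook(
    webhook="https://open.feishu.cn/open-apis/bot/v2/hook/xxxx-your-token",  # placeholder bot webhook
    title="123456_7b_train",  # send_alert_message() fills this with get_job_key() by default
    message="Overflow occurs, please check it.",
)
print(resp)  # parsed JSON response from Feishu, or None on error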

internlm/monitor/monitor.py (new file, 226 lines)

@@ -0,0 +1,226 @@
import os
import signal
import socket
import time
from contextlib import contextmanager
from threading import Thread

from internlm.core.context import global_context as gpc
from internlm.monitor.alert import send_feishu_msg_with_webhook
from internlm.utils.common import SingletonMeta

from .utils import get_job_key, set_env_var


def send_alert_message(address: str = None, title: str = None, message: str = None):
    """
    Send alert messages to the given Feishu webhook address in log rank.

    Args:
        address (str): The alert address to be used to send message, defaults to None.
        title (str): The message title, defaults to None.
        message (str): The message body, defaults to None.
    """

    if address is not None and gpc.is_rank_for_log():
        send_feishu_msg_with_webhook(
            webhook=address,
            title=title if title else get_job_key(),
            message=message,
        )


class MonitorTracker(Thread):
    """
    Track job status and alert to Feishu during job training.

    Args:
        alert_address (str): The Feishu webhook address for sending alerting messages.
        check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
        loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
    """

    def __init__(
        self,
        alert_address: str,
        check_interval: float = 300,
        loss_spike_limit: float = 1.5,
    ):
        super().__init__()
        self.alert_address = alert_address
        self.check_interval = check_interval
        self.loss_spike_limit = loss_spike_limit
        self.last_active_time = -1
        self.last_loss_value = -1
        self.stopped = False
        self.start()

    def run(self):
        """
        Start the monitor tracker.
        """

        while not self.stopped:
            try:
                self._check_stuck()
                self._check_loss_spike()
            except Exception:
                continue
            time.sleep(self.check_interval)

    def _check_stuck(self):
        """
        Check training status for potential stuck condition.
        """

        new_active_time = -1
        if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
            new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
        if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
            self._send_alert("Training may be in stuck status, please check it.")
        self.last_active_time = new_active_time

    def _check_loss_spike(self):
        """
        Check for loss value spikes.
        """

        if gpc.is_rank_for_log():
            new_loss_value = -1
            new_step_id = -1
            if os.getenv("LOSS") is not None:
                new_loss_value = os.getenv("LOSS")
            if os.getenv("STEP_ID") is not None:
                new_step_id = os.getenv("STEP_ID")
            if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
                assert int(new_step_id) >= 0
                self._send_alert(
                    f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
                    f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
                )
            self.last_loss_value = new_loss_value

    def _send_alert(self, message):
        """
        Send alerting message to the Feishu webhook address.

        Args:
            message (str): The alerting message to be sent.
        """

        send_alert_message(
            address=self.alert_address,
            message=message,
        )

    def stop(self):
        """
        Stop the monitor tracker.
        """

        self.stopped = True


class MonitorManager(metaclass=SingletonMeta):
    """
    Monitor Manager for managing monitor thread and monitoring training status.
    """

    def __init__(self, loss_spike_limit: float = 1.5) -> None:
        self.monitor_thread = None
        self.loss_spike_limit = loss_spike_limit
        self.last_step_loss = -1

    def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
        """Check loss value, if loss spike occurs, send alert message to Feishu."""
        set_env_var(key="LOSS", value=cur_step_loss)
        set_env_var(key="STEP_ID", value=step_count)

        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
            send_alert_message(
                address=alert_address,
                message=(
                    f"Checking step by step: Loss spike may be happened in step {step_count}, "
                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
                ),
            )
        self.last_step_loss = cur_step_loss

    def monitor_exception(self, alert_address: str = None, excp_info: str = None):
        """Catch and format exception information, send alert message to Feishu."""
        filtered_trace = excp_info.split("\n")[-10:]
        format_trace = ""
        for line in filtered_trace:
            format_trace += "\n" + line
        send_alert_message(
            address=alert_address,
            message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
        )

    def handle_sigterm(self, alert_address: str = None):
        """Catch SIGTERM signal, and send alert message to Feishu."""

        def sigterm_handler(sys_signal, frame):
            print("receive frame: ", frame)
            print("receive signal: ", sys_signal)
            send_alert_message(
                address=alert_address,
                message=f"Process received signal {signal} and exited.",
            )

        signal.signal(signal.SIGTERM, sigterm_handler)

    def start_monitor(
        self,
        job_name: str,
        alert_address: str,
        monitor_interval_seconds: int = 300,
        loss_spike_limit: float = 1.5,
    ):
        """
        Initialize and start monitor thread for checking training job status, loss spike and so on.

        Args:
            job_name (str): The training job name.
            alert_address (str): The Feishu webhook address for sending alert messages.
            monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
            loss_spike_limit (float): The limit multiple of current loss to previous loss value, above which a loss
                spike may have occurred, defaults to 1.5.
        """

        # initialize some variables for monitoring
        set_env_var(key="JOB_NAME", value=job_name)

        # start a monitor thread, periodically check the training status
        self.monitor_thread = MonitorTracker(
            alert_address=alert_address,
            check_interval=monitor_interval_seconds,
            loss_spike_limit=loss_spike_limit,
        )

    def stop_monitor(self):
        """Stop the monitor and alert thread."""
        if self.monitor_thread is not None:
            self.monitor_thread.stop()


monitor_manager = MonitorManager()


@contextmanager
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
    if alert_address is not None:
        try:
            monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
            monitor_manager.handle_sigterm(alert_address=alert_address)
            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
            yield
        finally:
            send_alert_message(
                address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
            )
            monitor_manager.stop_monitor()
    else:
        yield
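A hedged sketch of how these entry points fit together in a training script. Names mirror the code above; the webhook is a placeholder, and the distributed context is assumed to be initialized already so that gpc.is_rank_for_log() works:

from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm

webhook = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxx-your-token"  # placeholder

with initialize_monitor_manager(job_name="7b_train", alert_address=webhook):
    for step, loss in enumerate([2.1, 2.0, 3.6]):  # toy loss curve: 3.6 > 1.5 * 2.0 triggers a spike alert
        mm.monitor_loss_spike(alert_address=webhook, step_count=step, cur_step_loss=loss)
    send_alert_message(address=webhook, message="Custom milestone reached.")  # ad-hoc alert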

internlm/monitor/utils.py (new file, 32 lines)

@@ -0,0 +1,32 @@
import os
from datetime import datetime


def now_time():
    return datetime.now().strftime("%b%d_%H-%M-%S")


def set_env_var(key, value):
    os.environ[str(key)] = str(value)


def get_job_id():
    job_id = "none"
    if os.getenv("SLURM_JOB_ID") is not None:
        job_id = os.getenv("SLURM_JOB_ID")
    elif os.getenv("K8S_WORKSPACE_ID") is not None:
        job_id = os.getenv("K8S_WORKSPACE_ID")

    return job_id


def get_job_name():
    job_name = f"unknown-{now_time()}"
    if os.getenv("JOB_NAME") is not None:
        job_name = os.getenv("JOB_NAME")

    return job_name


def get_job_key():
    return f"{get_job_id()}_{get_job_name()}"
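A quick hedged illustration of these helpers; the Slurm job id and job name below are made-up values:

import os

from internlm.monitor.utils import get_job_key, set_env_var

os.environ["SLURM_JOB_ID"] = "123456"  # hypothetical Slurm allocation id
set_env_var("JOB_NAME", "7b_train")    # normally set by MonitorManager.start_monitor()
print(get_job_key())                   # -> "123456_7b_train", used as the default alert title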


@@ -28,6 +28,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.monitor import send_alert_message

 from .utils import compute_norm
@@ -542,6 +543,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
+                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None
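One point worth noting, shown as a hedged sketch: send_alert_message() checks the address before doing anything, so the new overflow alert degrades to a no-op when alert_address is left at its default of None:

from internlm.monitor import send_alert_message

# with the default config (alert_address = None) this returns immediately without sending anything,
# so overflow handling still only emits the logger warning
send_alert_message(address=None, message="Overflow occurs, please check it.")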


@@ -34,18 +34,6 @@ def get_master_node():
     return result


-def get_process_rank():
-    proc_rank = -1
-    if os.getenv("SLURM_PROCID") is not None:
-        proc_rank = int(os.getenv("SLURM_PROCID"))
-    elif os.getenv("RANK") is not None:
-        # In k8s env, we use $RANK.
-        proc_rank = int(os.getenv("RANK"))
-
-    # assert proc_rank != -1, "get_process_rank cant't get right process rank!"
-    return proc_rank
-
-
 def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
     if torch.is_tensor(norm) and norm.device.type != "cuda":
         norm = norm.to(torch.cuda.current_device())


@@ -6,8 +6,8 @@ from tqdm import tqdm

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.metrics import AccPerplex
 from internlm.core.scheduler import SchedulerMetricHook
+from internlm.model.metrics import AccPerplex


 @contextmanager
@@ -90,15 +90,9 @@ def evaluate_on_val_dls(
                 total_val_bsz = len(batch[1])
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                if gpc.config.model.sequence_parallel:
-                    sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
-                    )
-                else:
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
+                tensor_shape = torch.Size(
+                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                )

                 with switch_evaluation_pipeline_scheduler(
                     trainer=trainer,
@@ -114,7 +108,6 @@ def evaluate_on_val_dls(
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                 grad_accum_batch_size = data_cfg.micro_bsz
-                # import pdb; pdb.set_trace()
                 with switch_evaluation_no_pipeline_scheduler(
                     trainer=trainer,
                     grad_accum_size=grad_accum_size,
@@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
             gpc.config.model.sequence_parallel = False
         yield
     finally:
         gpc.config.model.sequence_parallel = prev_mode


@@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
+from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
@@ -37,7 +39,6 @@ from internlm.utils.common import (
     BatchSkipper,
     get_master_node,
     get_megatron_flops,
-    get_process_rank,
     launch_time,
     parse_args,
 )
@@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
 def initialize_llm_logger(start_time: str):
+    """
+    Initialize customized uniscale logger.
+
+    Args:
+        start_time (str): The launch time of current training job.
+
+    Returns: The instance of uniscale logger.
+    """
+
     uniscale_logger = initialize_uniscale_logger(
         job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
     )
@@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
 def get_validation_data_loader(num_worker: int = 0):
+    """Generate and return the validation data loader."""
+
     data_cfg = gpc.config.data

     if not data_cfg.valid_folder:
@@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
     Print some training metrics of current batch.
     """
+    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
+
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
     if is_no_pp_or_last_stage():
@@ -405,12 +419,11 @@ record_current_batch_training_metrics(
     else:
         logger.info(line)

+    # if loss spike occurs, send alert info to feishu
+    mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
+

 def main(args):
-    # initialize distributed environment
-    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
-    assert hasattr(gpc, "config") and gpc.config is not None
-
     # init setting
     skip_batches = gpc.config.data.skip_batches
     total_steps = gpc.config.data.total_steps
@@ -477,8 +490,8 @@ def main(args):
         model_load_path = load_model_only_folder
     else:
         logger.info(
-            f"===========New Run {current_time} on host:{socket.gethostname()},"
-            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
+            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
+            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
             f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
         )
@@ -594,6 +607,9 @@ def main(args):
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
             if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                send_alert_message(
+                    address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
+                )

         # calculate and record the training metrics, eg. loss, accuracy and so on.
         record_current_batch_training_metrics(
@@ -646,9 +662,19 @@ def main(args):
 if __name__ == "__main__":
     args = parse_args()
+    hostname = socket.gethostname()

-    try:
-        main(args)
-    except Exception:
-        print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
-        traceback.print_exc()
+    # initialize distributed environment
+    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+
+    # initialize monitor manager context
+    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+        try:
+            main(args)
+        except Exception:
+            logger.error(
+                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
+                exc_info=traceback.format_exc(),
+            )
+            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
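To make the moving parts easier to follow: the training loop and the background MonitorTracker thread communicate only through process environment variables (LAST_ACTIVE_TIMESTAMP, LOSS, STEP_ID). Below is a standalone, hedged toy sketch of that handshake with a hypothetical two-second check interval; the real tracker checks every 300 seconds and sends Feishu alerts instead of printing:

import os
import time
from threading import Thread


def set_env_var(key, value):
    os.environ[str(key)] = str(value)


class TinyTracker(Thread):
    """Toy stand-in for MonitorTracker: flags a stalled loop via LAST_ACTIVE_TIMESTAMP."""

    def __init__(self, check_interval=2.0):
        super().__init__(daemon=True)
        self.check_interval = check_interval
        self.last_seen = -1

    def run(self):
        while True:
            ts = int(os.getenv("LAST_ACTIVE_TIMESTAMP", -1))
            if ts != -1 and ts <= self.last_seen:
                print("training may be stuck")  # the real tracker calls send_alert_message() here
            self.last_seen = ts
            time.sleep(self.check_interval)


TinyTracker().start()
for step in range(3):
    set_env_var("LAST_ACTIVE_TIMESTAMP", int(time.time()))  # heartbeat written by the training loop
    set_env_var("STEP_ID", step)
    time.sleep(1)  # simulated training step

time.sleep(5)  # heartbeats stop; the tracker soon sees a stale timestamp and prints the warning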