fix(monitor): add alert switch and refactor monitor config (#285)

* add monitor switch

* add switch to light monitor

* fix handling of empty alert_address

* fix light monitor heartbeat

* init light_monitor on the log rank only

* add comments to the monitoring config

* optimize config
jiaopenglong 2023-09-07 21:49:05 +08:00 committed by GitHub
parent 37b8c6684e
commit 7c99e01ca7
6 changed files with 51 additions and 21 deletions

@@ -1,4 +1,5 @@
JOB_NAME = "7b_train"
DO_ALERT = False
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
@@ -150,3 +151,12 @@ parallel = dict(
cudnn_deterministic = False
cudnn_benchmark = False
monitor = dict(
# feishu alert configs
alert=dict(
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
),
)
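
A user turns alerting on by flipping DO_ALERT and filling in the two addresses in this block. A minimal sketch of such an override, where both address values are placeholders rather than real endpoints:

DO_ALERT = True  # master switch read by monitor.alert.enable_feishu_alert

monitor = dict(
    # feishu alert configs
    alert=dict(
        enable_feishu_alert=DO_ALERT,
        feishu_alert_address="https://open.feishu.cn/open-apis/bot/v2/hook/xxxx",  # placeholder webhook URL
        light_monitor_address="127.0.0.1:6789",  # placeholder light_monitor endpoint
    ),
)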

@@ -254,9 +254,22 @@ def args_sanity_check():
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
# monitoring default config
monitor_default_config = {
"alert_address": None, # compatible with old alert config
"monitor": { # new monitoring config
"alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
},
}
for key, value in monitor_default_config.items():
if key not in gpc.config:
gpc.config._add_item(key, value)
alert = gpc.config.monitor.alert
if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
logger.warning("alert is enable but alert_address is not set")
optim_ckpt = gpc.config.hybrid_zero_optimizer
if "zero_overlap_communication" in optim_ckpt:
@@ -334,14 +347,6 @@
f"tensor parallel size: {gpc.tensor_parallel_size}",
)
# init light monitor client
light_monitor_address = gpc.config.get("light_monitor_address", None)
if light_monitor_address is None:
if gpc.is_rank_for_log():
logger.warning("monitor address is none, monitor could not be used!")
else:
initialize_light_monitor(light_monitor_address)
def launch_from_slurm(
config: Union[str, Path, Config, Dict],
@@ -446,6 +451,15 @@ def initialize_distributed_env(
if args_check:
args_sanity_check()
# init light monitor client
alert_config = gpc.config.monitor.alert
if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
light_monitor_address = alert_config.light_monitor_address
if light_monitor_address:
initialize_light_monitor(light_monitor_address)
else:
logger.warning("monitor address is none, monitor could not be used!")
def get_config_value(config, key, defalut):
try:

@@ -218,9 +218,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield
finally:
send_alert_message(
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
monitor_manager.stop_monitor()
else:
yield

@@ -580,7 +580,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
address=gpc.config.monitor.alert.feishu_alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()

@@ -406,7 +406,7 @@ def record_current_batch_training_metrics(
else:
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if gpc.config.get("light_monitor_address", None) and batch_count % 50 == 0:
if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
send_heartbeat("train_metrics", infos)
if update_panel:
@@ -434,4 +434,8 @@ def record_current_batch_training_metrics(
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
mm.monitor_loss_spike(
alert_address=gpc.config.monitor.alert.feishu_alert_address,
step_count=batch_count,
cur_step_loss=loss.item(),
)

@@ -120,7 +120,7 @@ def main(args):
train_dl=train_dl,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.alert_address,
feishu_address=gpc.config.monitor.alert.feishu_alert_address,
)
# Loading other persistent training states.
@@ -237,7 +237,7 @@ def main(args):
if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address,
address=gpc.config.monitor.alert.feishu_alert_address,
message=f"Warning: skip parameter update at step {batch_count}.",
)
@@ -297,11 +297,15 @@ if __name__ == "__main__":
assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
with initialize_monitor_manager(
job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
):
try:
main(args)
except Exception:
logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
)
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
mm.monitor_exception(
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
)