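Modify submitted code: move alert settings under gpc.config.monitor.alert and initialize the light monitor in initialize_distributed_env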

pull/383/head
haijunlv 2023-09-28 14:09:26 +08:00
parent ff181bc5f8
commit 870dd7ddc6
3 changed files with 30 additions and 18 deletions

View File

@ -253,9 +253,22 @@ def args_sanity_check():
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False" ), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting # monitoring default config
if "alert_address" not in gpc.config: monitor_default_config = {
gpc.config._add_item("alert_address", None) "alert_address": None, # compatible with old alert config
"monitor": { # new monitoring config
"alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
},
}
for key, value in monitor_default_config.items():
if key not in gpc.config:
gpc.config._add_item(key, value)
alert = gpc.config.monitor.alert
if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
logger.warning("alert is enable but alert_address is not set")
optim_ckpt = gpc.config.hybrid_zero_optimizer optim_ckpt = gpc.config.hybrid_zero_optimizer
if "zero_overlap_communication" in optim_ckpt: if "zero_overlap_communication" in optim_ckpt:
@ -333,14 +346,6 @@ def launch(
f"tensor parallel size: {gpc.tensor_parallel_size}", f"tensor parallel size: {gpc.tensor_parallel_size}",
) )
# init light monitor client
light_monitor_address = gpc.config.get("light_monitor_address", None)
if light_monitor_address is None:
if gpc.is_rank_for_log():
logger.warning("monitor address is none, monitor could not be used!")
else:
initialize_light_monitor(light_monitor_address)
def launch_from_slurm( def launch_from_slurm(
config: Union[str, Path, Config, Dict], config: Union[str, Path, Config, Dict],
@ -444,6 +449,15 @@ def initialize_distributed_env(
if args_check: if args_check:
args_sanity_check() args_sanity_check()
# init light monitor client
alert_config = gpc.config.monitor.alert
if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
light_monitor_address = alert_config.light_monitor_address
if light_monitor_address:
initialize_light_monitor(light_monitor_address)
else:
logger.warning("monitor address is none, monitor could not be used!")
def get_config_value(config, key, defalut): def get_config_value(config, key, defalut):
try: try:

View File

@ -218,9 +218,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.") send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield yield
finally: finally:
send_alert_message( send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
monitor_manager.stop_monitor() monitor_manager.stop_monitor()
else: else:
yield yield

View File

@ -578,7 +578,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.") logger.warning("Overflow occurs, please check it.")
send_alert_message( send_alert_message(
address=gpc.config.alert_address, address=gpc.config.monitor.alert.feishu_alert_address,
message="Overflow occurs, please check it.", message="Overflow occurs, please check it.",
) )
self._grad_store._averaged_gradients = dict() self._grad_store._averaged_gradients = dict()