modify code submit

pull/383/head
haijunlv 2023-09-28 14:09:26 +08:00
parent ff181bc5f8
commit 870dd7ddc6
3 changed files with 30 additions and 18 deletions

View File

@ -253,9 +253,22 @@ def args_sanity_check():
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
# monitoring default config
monitor_default_config = {
"alert_address": None, # compatible with old alert config
"monitor": { # new monitoring config
"alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
},
}
for key, value in monitor_default_config.items():
if key not in gpc.config:
gpc.config._add_item(key, value)
alert = gpc.config.monitor.alert
if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
logger.warning("alert is enable but alert_address is not set")
optim_ckpt = gpc.config.hybrid_zero_optimizer
if "zero_overlap_communication" in optim_ckpt:
@ -333,14 +346,6 @@ def launch(
f"tensor parallel size: {gpc.tensor_parallel_size}",
)
# init light monitor client
light_monitor_address = gpc.config.get("light_monitor_address", None)
if light_monitor_address is None:
if gpc.is_rank_for_log():
logger.warning("monitor address is none, monitor could not be used!")
else:
initialize_light_monitor(light_monitor_address)
def launch_from_slurm(
config: Union[str, Path, Config, Dict],
@ -444,10 +449,19 @@ def initialize_distributed_env(
if args_check:
args_sanity_check()
# init light monitor client
alert_config = gpc.config.monitor.alert
if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
light_monitor_address = alert_config.light_monitor_address
if light_monitor_address:
initialize_light_monitor(light_monitor_address)
else:
logger.warning("monitor address is none, monitor could not be used!")
def get_config_value(config, key, defalut):
try:
value = config[key]
except KeyError:
value = defalut
return value
return value

View File

@ -218,9 +218,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield
finally:
send_alert_message(
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
monitor_manager.stop_monitor()
else:
yield
yield

View File

@ -578,7 +578,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
address=gpc.config.monitor.alert.feishu_alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
@ -784,4 +784,4 @@ def reload_zero_fp32_buff(optimizer):
optimizer._zero_local_rank, group_id
)
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())