mirror of https://github.com/InternLM/InternLM
feat(monitor): support monitor and alert (#175)
* feat(monitor): support monitor and alert * feat(monitor.py): fix demo error * feat(monitor.py): move cmd monitor args to config file * feat(hybrid_zero_optim.py): if overflow occurs send alert msg * feat(monitor.py): remove alert msg filter * feat(monitor.py): optimize class MonitorTracker * feat(monitor.py): optimize code * feat(monitor.py): optimize code * feat(monitor.py): optimize code * feat(monitor.py): optimize code * feat(train.py): update print to log * style(ci): fix lint error * fix(utils/evaluation.py): remove useless code * fix(model/modeling_internlm.py): fix lint error --------- Co-authored-by: huangting4201 <huangting3@sensetime.com>pull/190/head
parent
c219065348
commit
ff0fa7659f
|
@ -202,7 +202,13 @@ def args_sanity_check():
|
|||
if "sequence_parallel" not in gpc.config.model:
|
||||
gpc.config.model._add_item("sequence_parallel", False)
|
||||
else:
|
||||
assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
|
||||
assert not (
|
||||
gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
|
||||
), "sequence parallel does not support use_flash_attn=False"
|
||||
|
||||
# feishu webhook address for alerting
|
||||
if "alert_address" not in gpc.config:
|
||||
gpc.config._add_item("alert_address", None)
|
||||
|
||||
|
||||
def launch(
|
||||
|
|
|
@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
|
|||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
||||
|
||||
|
||||
if gpc.config.model.sequence_parallel:
|
||||
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
|
||||
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func_torch(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
|
||||
input,
|
||||
weight,
|
||||
self.bias,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
)
|
||||
|
||||
|
||||
|
@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
|||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func_torch(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
|
||||
input,
|
||||
weight,
|
||||
self.bias,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
)
|
||||
|
||||
|
||||
|
@ -170,7 +178,13 @@ class FeedForward(nn.Module):
|
|||
dtype=dtype,
|
||||
)
|
||||
self.w2 = ColumnParallelLinearTorch(
|
||||
in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w3 = RowParallelLinearTorch(
|
||||
hidden_features,
|
||||
|
|
|
@ -31,6 +31,7 @@ MODEL_TYPE = "INTERNLM"
|
|||
logger = get_logger(__file__)
|
||||
RMSNorm = try_import_RMSNorm()
|
||||
|
||||
|
||||
class PackedFlashBaseLayer1D(nn.Module):
|
||||
"""
|
||||
1D Packed Flash Base Layer.
|
||||
|
@ -461,7 +462,7 @@ def build_model_with_cfg(
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
sequence_parallel: bool = False,
|
||||
sequence_parallel: bool = False, # pylint: disable=W0613
|
||||
):
|
||||
"""
|
||||
Builde model with config
|
||||
|
|
|
@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
|
|||
from torch.distributed import ProcessGroup
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def _split(input_, parallel_mode, dim=-1):
|
||||
|
@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
|||
def gather_forward_split_backward(input_, parallel_mode, dim):
|
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
|
||||
assert input.dtype == grad_output.dtype
|
||||
grad_weight = torch.matmul(grad_output.t(), input)
|
||||
|
@ -157,10 +161,11 @@ def fused_dense_func_torch(
|
|||
else:
|
||||
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split the input and keep only the corresponding chuck to the rank.
|
||||
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
|
@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
|
|||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, parallel_mode, dim):
|
||||
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
|
||||
|
@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
|
|||
def try_import_RMSNorm():
|
||||
"""
|
||||
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
|
||||
|
||||
|
||||
"""
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||
|
||||
return RMSNorm
|
||||
except ModuleNotFoundError as e:
|
||||
from internlm.utils.logger import get_logger
|
||||
logger = get_logger(__file__)
|
||||
except ModuleNotFoundError:
|
||||
logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
|
||||
from internlm.model.norm import RMSNormTorch as RMSNorm
|
||||
|
||||
return RMSNorm
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .monitor import initialize_monitor_manager, send_alert_message
|
||||
from .utils import set_env_var
|
||||
|
||||
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
|
|
@ -0,0 +1,53 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
|
||||
"""
|
||||
Use Feishu robot to send messages with the given webhook.
|
||||
|
||||
Args:
|
||||
webhook (str): The webhook to be used to send message.
|
||||
title (str): The message title.
|
||||
message (str): The message body.
|
||||
|
||||
Returns:
|
||||
The response from the request. Or catch the exception and return None.
|
||||
|
||||
Raises:
|
||||
Exception: An exception rasied by the HTTP post request.
|
||||
|
||||
"""
|
||||
|
||||
headers = {"Content-Type": "application/json;charset=utf-8"}
|
||||
msg_body = {
|
||||
"timestamp": int(time.time()),
|
||||
"msg_type": "post",
|
||||
"content": {
|
||||
"post": {
|
||||
"zh_cn": {
|
||||
"title": title,
|
||||
"content": [
|
||||
[
|
||||
{
|
||||
"tag": "text",
|
||||
"text": message,
|
||||
},
|
||||
],
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
|
||||
res = res.json()
|
||||
print(f"Feishu webhook response: {res}")
|
||||
except Exception as err: # pylint: disable=W0703
|
||||
print(f"HTTP Post error: {err}")
|
||||
res = None
|
||||
|
||||
return res
|
|
@ -0,0 +1,226 @@
|
|||
import os
|
||||
import signal
|
||||
import socket
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from threading import Thread
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.monitor.alert import send_feishu_msg_with_webhook
|
||||
from internlm.utils.common import SingletonMeta
|
||||
|
||||
from .utils import get_job_key, set_env_var
|
||||
|
||||
|
||||
def send_alert_message(address: str = None, title: str = None, message: str = None):
|
||||
"""
|
||||
Send alert messages to the given Feishu webhook address in log rank.
|
||||
|
||||
Args:
|
||||
address (str): The alert address to be used to send message, defaults to None.
|
||||
title (str): The message title, defaults to None.
|
||||
message (str): The message body, defaults to None.
|
||||
"""
|
||||
|
||||
if address is not None and gpc.is_rank_for_log():
|
||||
send_feishu_msg_with_webhook(
|
||||
webhook=address,
|
||||
title=title if title else get_job_key(),
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class MonitorTracker(Thread):
|
||||
"""
|
||||
Track job status and alert to Feishu during job training.
|
||||
|
||||
Args:
|
||||
alert_address (str): The Feishu webhook address for sending alerting messages.
|
||||
check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
|
||||
loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alert_address: str,
|
||||
check_interval: float = 300,
|
||||
loss_spike_limit: float = 1.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.alert_address = alert_address
|
||||
self.check_interval = check_interval
|
||||
self.loss_spike_limit = loss_spike_limit
|
||||
self.last_active_time = -1
|
||||
self.last_loss_value = -1
|
||||
self.stopped = False
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
start the monitor tracker.
|
||||
"""
|
||||
|
||||
while not self.stopped:
|
||||
try:
|
||||
self._check_stuck()
|
||||
self._check_loss_spike()
|
||||
except Exception:
|
||||
continue
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def _check_stuck(self):
|
||||
"""
|
||||
Check training status for potential stuck condition.
|
||||
"""
|
||||
|
||||
new_active_time = -1
|
||||
if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
|
||||
new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
|
||||
if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
|
||||
self._send_alert("Training may be in stuck status, please check it.")
|
||||
self.last_active_time = new_active_time
|
||||
|
||||
def _check_loss_spike(self):
|
||||
"""
|
||||
Check for loss value spikes.
|
||||
"""
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
new_loss_value = -1
|
||||
new_step_id = -1
|
||||
if os.getenv("LOSS") is not None:
|
||||
new_loss_value = os.getenv("LOSS")
|
||||
if os.getenv("STEP_ID") is not None:
|
||||
new_step_id = os.getenv("STEP_ID")
|
||||
|
||||
if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
|
||||
assert int(new_step_id) >= 0
|
||||
self._send_alert(
|
||||
f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
|
||||
f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
|
||||
)
|
||||
|
||||
self.last_loss_value = new_loss_value
|
||||
|
||||
def _send_alert(self, message):
|
||||
"""
|
||||
Send alerting message to the Feishu webhook address.
|
||||
|
||||
Args:
|
||||
message (str): The alerting message to be sent.
|
||||
"""
|
||||
|
||||
send_alert_message(
|
||||
address=self.alert_address,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the monitor tracker.
|
||||
"""
|
||||
|
||||
self.stopped = True
|
||||
|
||||
|
||||
class MonitorManager(metaclass=SingletonMeta):
|
||||
"""
|
||||
Monitor Manager for managing monitor thread and monitoring training status.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_spike_limit: float = 1.5) -> None:
|
||||
self.monitor_thread = None
|
||||
self.loss_spike_limit = loss_spike_limit
|
||||
self.last_step_loss = -1
|
||||
|
||||
def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
|
||||
"""Check loss value, if loss spike occurs, send alert message to Feishu."""
|
||||
set_env_var(key="LOSS", value=cur_step_loss)
|
||||
set_env_var(key="STEP_ID", value=step_count)
|
||||
|
||||
if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
|
||||
send_alert_message(
|
||||
address=alert_address,
|
||||
message=(
|
||||
f"Checking step by step: Loss spike may be happened in step {step_count}, "
|
||||
f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
|
||||
),
|
||||
)
|
||||
self.last_step_loss = cur_step_loss
|
||||
|
||||
def monitor_exception(self, alert_address: str = None, excp_info: str = None):
|
||||
"""Catch and format exception information, send alert message to Feishu."""
|
||||
filtered_trace = excp_info.split("\n")[-10:]
|
||||
format_trace = ""
|
||||
for line in filtered_trace:
|
||||
format_trace += "\n" + line
|
||||
send_alert_message(
|
||||
address=alert_address,
|
||||
message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
|
||||
)
|
||||
|
||||
def handle_sigterm(self, alert_address: str = None):
|
||||
"""Catch SIGTERM signal, and send alert message to Feishu."""
|
||||
|
||||
def sigterm_handler(sys_signal, frame):
|
||||
print("receive frame: ", frame)
|
||||
print("receive signal: ", sys_signal)
|
||||
send_alert_message(
|
||||
address=alert_address,
|
||||
message=f"Process received signal {signal} and exited.",
|
||||
)
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
def start_monitor(
|
||||
self,
|
||||
job_name: str,
|
||||
alert_address: str,
|
||||
monitor_interval_seconds: int = 300,
|
||||
loss_spike_limit: float = 1.5,
|
||||
):
|
||||
"""
|
||||
Initialize and start monitor thread for checking training job status, loss spike and so on.
|
||||
|
||||
Args:
|
||||
job_name (str): The training job name.
|
||||
alert_address (str): The Feishu webhook address for sending alert messages.
|
||||
monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
|
||||
loss_spike_limit (float): The limit multiple of current loss to previous loss value, which means loss spike
|
||||
may be occurs, defaults to 1.5.
|
||||
"""
|
||||
|
||||
# initialize some variables for monitoring
|
||||
set_env_var(key="JOB_NAME", value=job_name)
|
||||
|
||||
# start a monitor thread, periodically check the training status
|
||||
self.monitor_thread = MonitorTracker(
|
||||
alert_address=alert_address,
|
||||
check_interval=monitor_interval_seconds,
|
||||
loss_spike_limit=loss_spike_limit,
|
||||
)
|
||||
|
||||
def stop_monitor(self):
|
||||
"""Stop the monitor and alert thread."""
|
||||
if self.monitor_thread is not None:
|
||||
self.monitor_thread.stop()
|
||||
|
||||
|
||||
monitor_manager = MonitorManager()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
|
||||
if alert_address is not None:
|
||||
try:
|
||||
monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
|
||||
monitor_manager.handle_sigterm(alert_address=alert_address)
|
||||
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
|
||||
yield
|
||||
finally:
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
|
||||
)
|
||||
monitor_manager.stop_monitor()
|
||||
else:
|
||||
yield
|
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def now_time():
|
||||
return datetime.now().strftime("%b%d_%H-%M-%S")
|
||||
|
||||
|
||||
def set_env_var(key, value):
|
||||
os.environ[str(key)] = str(value)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = "none"
|
||||
if os.getenv("SLURM_JOB_ID") is not None:
|
||||
job_id = os.getenv("SLURM_JOB_ID")
|
||||
elif os.getenv("K8S_WORKSPACE_ID") is not None:
|
||||
job_id = os.getenv("K8S_WORKSPACE_ID")
|
||||
|
||||
return job_id
|
||||
|
||||
|
||||
def get_job_name():
|
||||
job_name = f"unknown-{now_time()}"
|
||||
if os.getenv("JOB_NAME") is not None:
|
||||
job_name = os.getenv("JOB_NAME")
|
||||
|
||||
return job_name
|
||||
|
||||
|
||||
def get_job_key():
|
||||
return f"{get_job_id()}_{get_job_name()}"
|
|
@ -28,6 +28,7 @@ from internlm.solver.optimizer.utils import (
|
|||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.monitor import send_alert_message
|
||||
|
||||
from .utils import compute_norm
|
||||
|
||||
|
@ -542,6 +543,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if found_inf:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("Overflow occurs, please check it.")
|
||||
send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
|
||||
self._grad_store._averaged_gradients = dict()
|
||||
self.zero_grad()
|
||||
return False, None
|
||||
|
|
|
@ -34,18 +34,6 @@ def get_master_node():
|
|||
return result
|
||||
|
||||
|
||||
def get_process_rank():
|
||||
proc_rank = -1
|
||||
if os.getenv("SLURM_PROCID") is not None:
|
||||
proc_rank = int(os.getenv("SLURM_PROCID"))
|
||||
elif os.getenv("RANK") is not None:
|
||||
# In k8s env, we use $RANK.
|
||||
proc_rank = int(os.getenv("RANK"))
|
||||
|
||||
# assert proc_rank != -1, "get_process_rank cant't get right process rank!"
|
||||
return proc_rank
|
||||
|
||||
|
||||
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
if torch.is_tensor(norm) and norm.device.type != "cuda":
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
|
|
|
@ -6,8 +6,8 @@ from tqdm import tqdm
|
|||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
from internlm.model.metrics import AccPerplex
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -90,15 +90,9 @@ def evaluate_on_val_dls(
|
|||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||
if gpc.config.model.sequence_parallel:
|
||||
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
else:
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
|
||||
with switch_evaluation_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
|
@ -114,7 +108,6 @@ def evaluate_on_val_dls(
|
|||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
||||
grad_accum_batch_size = data_cfg.micro_bsz
|
||||
# import pdb; pdb.set_trace()
|
||||
with switch_evaluation_no_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
grad_accum_size=grad_accum_size,
|
||||
|
@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
|
|||
gpc.config.model.sequence_parallel = False
|
||||
yield
|
||||
finally:
|
||||
gpc.config.model.sequence_parallel = prev_mode
|
||||
gpc.config.model.sequence_parallel = prev_mode
|
||||
|
|
50
train.py
50
train.py
|
@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
|
|||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
|
@ -37,7 +39,6 @@ from internlm.utils.common import (
|
|||
BatchSkipper,
|
||||
get_master_node,
|
||||
get_megatron_flops,
|
||||
get_process_rank,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
|
@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
|
|||
|
||||
|
||||
def initialize_llm_logger(start_time: str):
|
||||
"""
|
||||
Initialize customed uniscale logger.
|
||||
|
||||
Args:
|
||||
start_time (str): The launch time of current training job.
|
||||
|
||||
Returns: The instance of uniscale logger.
|
||||
"""
|
||||
|
||||
uniscale_logger = initialize_uniscale_logger(
|
||||
job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
|
||||
)
|
||||
|
@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
|
|||
|
||||
|
||||
def get_validation_data_loader(num_worker: int = 0):
|
||||
"""Generate and return the validation data loader."""
|
||||
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
if not data_cfg.valid_folder:
|
||||
|
@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
|
|||
Print some training metrics of current batch.
|
||||
"""
|
||||
|
||||
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
|
||||
|
||||
if success_update in (0, True):
|
||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||
if is_no_pp_or_last_stage():
|
||||
|
@ -405,12 +419,11 @@ def record_current_batch_training_metrics(
|
|||
else:
|
||||
logger.info(line)
|
||||
|
||||
# if loss spike occurs, send alert info to feishu
|
||||
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
|
||||
|
||||
|
||||
def main(args):
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# init setting
|
||||
skip_batches = gpc.config.data.skip_batches
|
||||
total_steps = gpc.config.data.total_steps
|
||||
|
@ -477,8 +490,8 @@ def main(args):
|
|||
model_load_path = load_model_only_folder
|
||||
else:
|
||||
logger.info(
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},"
|
||||
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
|
||||
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
|
||||
)
|
||||
|
||||
|
@ -594,6 +607,9 @@ def main(args):
|
|||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
|
||||
)
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
|
@ -646,9 +662,19 @@ def main(args):
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
|
||||
traceback.print_exc()
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# initialize monitor manager context
|
||||
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
|
||||
exc_info=traceback.format_exc(),
|
||||
)
|
||||
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
|
||||
|
|
Loading…
Reference in New Issue