mirror of https://github.com/InternLM/InternLM
feat(monitor): support monitor and alert (#175)
* feat(monitor): support monitor and alert
* feat(monitor.py): fix demo error
* feat(monitor.py): move cmd monitor args to config file
* feat(hybrid_zero_optim.py): if overflow occurs send alert msg
* feat(monitor.py): remove alert msg filter
* feat(monitor.py): optimize class MonitorTracker
* feat(monitor.py): optimize code
* feat(monitor.py): optimize code
* feat(monitor.py): optimize code
* feat(monitor.py): optimize code
* feat(train.py): update print to log
* style(ci): fix lint error
* fix(utils/evaluation.py): remove useless code
* fix(model/modeling_internlm.py): fix lint error

---------

Co-authored-by: huangting4201 <huangting3@sensetime.com>
parent c219065348
commit ff0fa7659f

@@ -202,7 +202,13 @@ def args_sanity_check():
     if "sequence_parallel" not in gpc.config.model:
         gpc.config.model._add_item("sequence_parallel", False)
     else:
-        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
+        assert not (
+            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+        ), "sequence parallel does not support use_flash_attn=False"
+
+    # feishu webhook address for alerting
+    if "alert_address" not in gpc.config:
+        gpc.config._add_item("alert_address", None)


 def launch(

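For reference, a minimal sketch of how the new alert_address key might be supplied from a training config. This snippet is illustrative only: the config file name and webhook URL are placeholders and are not part of this commit, while the JOB_NAME and alert_address key names match what the code above and train.py below read from gpc.config.

# hypothetical excerpt from a training config, e.g. a configs/7B_sft.py-style file
JOB_NAME = "7b_train"

# feishu webhook used by the new monitor/alert module; leave as None to disable alerting
alert_address = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxx"
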
@@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

         if gpc.config.model.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

         return output

@@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )

@@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )

@@ -170,7 +178,13 @@ class FeedForward(nn.Module):
             dtype=dtype,
         )
         self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.model.sequence_parallel,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = RowParallelLinearTorch(
             hidden_features,

@@ -31,6 +31,7 @@ MODEL_TYPE = "INTERNLM"
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()


 class PackedFlashBaseLayer1D(nn.Module):
     """
     1D Packed Flash Base Layer.

@@ -461,7 +462,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
-    sequence_parallel: bool = False,
+    sequence_parallel: bool = False,  # pylint: disable=W0613
 ):
     """
     Builde model with config

@@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup

 from internlm.core.context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)


 def _split(input_, parallel_mode, dim=-1):

@@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


 def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     assert input.dtype == grad_output.dtype
     grad_weight = torch.matmul(grad_output.t(), input)

@@ -157,10 +161,11 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)


 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
     Split the input and keep only the corresponding chuck to the rank.

     Args:
         input_: input matrix.
         parallel_mode: parallel mode.

@@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _gather(grad_output, ctx.mode, ctx.dim), None, None


 def split_forward_gather_backward(input_, parallel_mode, dim):
     return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)

@@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
 def try_import_RMSNorm():
     """
     Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm

     """
     try:
         from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm

         return RMSNorm
-    except ModuleNotFoundError as e:
-        from internlm.utils.logger import get_logger
-
-        logger = get_logger(__file__)
-
+    except ModuleNotFoundError:
         logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm

         return RMSNorm

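As an aside, a hypothetical usage sketch of the helper above, not taken from this commit: the import path and the (normalized_shape, eps) constructor signature are assumptions that should hold for both the apex MixedFusedRMSNorm and the torch fallback, and the hidden size is a placeholder.

from internlm.model.utils import try_import_RMSNorm

RMSNorm = try_import_RMSNorm()   # apex MixedFusedRMSNorm if available, RMSNormTorch otherwise
norm_layer = RMSNorm(4096, eps=1e-5)
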
@@ -0,0 +1,4 @@
+from .monitor import initialize_monitor_manager, send_alert_message
+from .utils import set_env_var
+
+__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]

@@ -0,0 +1,53 @@
+import json
+import time
+
+import requests
+
+
+def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
+    """
+    Use Feishu robot to send messages with the given webhook.
+
+    Args:
+        webhook (str): The webhook to be used to send message.
+        title (str): The message title.
+        message (str): The message body.
+
+    Returns:
+        The response from the request. Or catch the exception and return None.
+
+    Raises:
+        Exception: An exception rasied by the HTTP post request.
+    """
+
+    headers = {"Content-Type": "application/json;charset=utf-8"}
+    msg_body = {
+        "timestamp": int(time.time()),
+        "msg_type": "post",
+        "content": {
+            "post": {
+                "zh_cn": {
+                    "title": title,
+                    "content": [
+                        [
+                            {
+                                "tag": "text",
+                                "text": message,
+                            },
+                        ],
+                    ],
+                },
+            },
+        },
+    }
+
+    try:
+        res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
+        res = res.json()
+        print(f"Feishu webhook response: {res}")
+    except Exception as err:  # pylint: disable=W0703
+        print(f"HTTP Post error: {err}")
+        res = None
+
+    return res

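A minimal, hypothetical call to the helper above; the webhook URL is a placeholder and would normally come from the alert_address config entry.

from internlm.monitor.alert import send_feishu_msg_with_webhook

resp = send_feishu_msg_with_webhook(
    webhook="https://open.feishu.cn/open-apis/bot/v2/hook/xxxx",  # placeholder
    title="12345_demo_job",
    message="Training started on node-0.",
)
print(resp)  # parsed JSON response from Feishu, or None if the POST failed
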
@@ -0,0 +1,226 @@
+import os
+import signal
+import socket
+import time
+from contextlib import contextmanager
+from threading import Thread
+
+from internlm.core.context import global_context as gpc
+from internlm.monitor.alert import send_feishu_msg_with_webhook
+from internlm.utils.common import SingletonMeta
+
+from .utils import get_job_key, set_env_var
+
+
+def send_alert_message(address: str = None, title: str = None, message: str = None):
+    """
+    Send alert messages to the given Feishu webhook address in log rank.
+
+    Args:
+        address (str): The alert address to be used to send message, defaults to None.
+        title (str): The message title, defaults to None.
+        message (str): The message body, defaults to None.
+    """
+
+    if address is not None and gpc.is_rank_for_log():
+        send_feishu_msg_with_webhook(
+            webhook=address,
+            title=title if title else get_job_key(),
+            message=message,
+        )
+
+
+class MonitorTracker(Thread):
+    """
+    Track job status and alert to Feishu during job training.
+
+    Args:
+        alert_address (str): The Feishu webhook address for sending alerting messages.
+        check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
+        loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
+    """
+
+    def __init__(
+        self,
+        alert_address: str,
+        check_interval: float = 300,
+        loss_spike_limit: float = 1.5,
+    ):
+        super().__init__()
+        self.alert_address = alert_address
+        self.check_interval = check_interval
+        self.loss_spike_limit = loss_spike_limit
+        self.last_active_time = -1
+        self.last_loss_value = -1
+        self.stopped = False
+        self.start()
+
+    def run(self):
+        """
+        start the monitor tracker.
+        """
+
+        while not self.stopped:
+            try:
+                self._check_stuck()
+                self._check_loss_spike()
+            except Exception:
+                continue
+            time.sleep(self.check_interval)
+
+    def _check_stuck(self):
+        """
+        Check training status for potential stuck condition.
+        """
+
+        new_active_time = -1
+        if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
+            new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
+        if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
+            self._send_alert("Training may be in stuck status, please check it.")
+        self.last_active_time = new_active_time
+
+    def _check_loss_spike(self):
+        """
+        Check for loss value spikes.
+        """
+
+        if gpc.is_rank_for_log():
+            new_loss_value = -1
+            new_step_id = -1
+            if os.getenv("LOSS") is not None:
+                new_loss_value = os.getenv("LOSS")
+            if os.getenv("STEP_ID") is not None:
+                new_step_id = os.getenv("STEP_ID")
+
+            if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
+                assert int(new_step_id) >= 0
+                self._send_alert(
+                    f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
+                    f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
+                )
+
+            self.last_loss_value = new_loss_value
+
+    def _send_alert(self, message):
+        """
+        Send alerting message to the Feishu webhook address.
+
+        Args:
+            message (str): The alerting message to be sent.
+        """
+
+        send_alert_message(
+            address=self.alert_address,
+            message=message,
+        )
+
+    def stop(self):
+        """
+        Stop the monitor tracker.
+        """
+
+        self.stopped = True
+
+
+class MonitorManager(metaclass=SingletonMeta):
+    """
+    Monitor Manager for managing monitor thread and monitoring training status.
+    """
+
+    def __init__(self, loss_spike_limit: float = 1.5) -> None:
+        self.monitor_thread = None
+        self.loss_spike_limit = loss_spike_limit
+        self.last_step_loss = -1
+
+    def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
+        """Check loss value, if loss spike occurs, send alert message to Feishu."""
+        set_env_var(key="LOSS", value=cur_step_loss)
+        set_env_var(key="STEP_ID", value=step_count)
+
+        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
+            send_alert_message(
+                address=alert_address,
+                message=(
+                    f"Checking step by step: Loss spike may be happened in step {step_count}, "
+                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
+                ),
+            )
+        self.last_step_loss = cur_step_loss
+
+    def monitor_exception(self, alert_address: str = None, excp_info: str = None):
+        """Catch and format exception information, send alert message to Feishu."""
+        filtered_trace = excp_info.split("\n")[-10:]
+        format_trace = ""
+        for line in filtered_trace:
+            format_trace += "\n" + line
+        send_alert_message(
+            address=alert_address,
+            message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
+        )
+
+    def handle_sigterm(self, alert_address: str = None):
+        """Catch SIGTERM signal, and send alert message to Feishu."""
+
+        def sigterm_handler(sys_signal, frame):
+            print("receive frame: ", frame)
+            print("receive signal: ", sys_signal)
+            send_alert_message(
+                address=alert_address,
+                message=f"Process received signal {signal} and exited.",
+            )
+
+        signal.signal(signal.SIGTERM, sigterm_handler)
+
+    def start_monitor(
+        self,
+        job_name: str,
+        alert_address: str,
+        monitor_interval_seconds: int = 300,
+        loss_spike_limit: float = 1.5,
+    ):
+        """
+        Initialize and start monitor thread for checking training job status, loss spike and so on.
+
+        Args:
+            job_name (str): The training job name.
+            alert_address (str): The Feishu webhook address for sending alert messages.
+            monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
+            loss_spike_limit (float): The limit multiple of current loss to previous loss value, which means loss spike
+                may be occurs, defaults to 1.5.
+        """
+
+        # initialize some variables for monitoring
+        set_env_var(key="JOB_NAME", value=job_name)
+
+        # start a monitor thread, periodically check the training status
+        self.monitor_thread = MonitorTracker(
+            alert_address=alert_address,
+            check_interval=monitor_interval_seconds,
+            loss_spike_limit=loss_spike_limit,
+        )
+
+    def stop_monitor(self):
+        """Stop the monitor and alert thread."""
+        if self.monitor_thread is not None:
+            self.monitor_thread.stop()
+
+
+monitor_manager = MonitorManager()
+
+
+@contextmanager
+def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
+    if alert_address is not None:
+        try:
+            monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
+            monitor_manager.handle_sigterm(alert_address=alert_address)
+            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
+            yield
+        finally:
+            send_alert_message(
+                address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
+            )
+            monitor_manager.stop_monitor()
+    else:
+        yield

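Taken together, a hypothetical driver showing how these pieces are meant to be used from a training script (mirroring the train.py changes further down). The job name, webhook address and loss values are placeholders, and a launched distributed context (gpc) is assumed, since send_alert_message checks gpc.is_rank_for_log().

from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm

webhook = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxx"  # placeholder
with initialize_monitor_manager(job_name="demo_job", alert_address=webhook):
    for step, loss in enumerate([2.1, 2.0, 3.4]):
        # exports LOSS/STEP_ID for the background MonitorTracker and alerts
        # immediately if the loss jumps by more than loss_spike_limit (1.5x)
        mm.monitor_loss_spike(alert_address=webhook, step_count=step, cur_step_loss=loss)
    send_alert_message(address=webhook, message="demo finished")
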
@@ -0,0 +1,32 @@
+import os
+from datetime import datetime
+
+
+def now_time():
+    return datetime.now().strftime("%b%d_%H-%M-%S")
+
+
+def set_env_var(key, value):
+    os.environ[str(key)] = str(value)
+
+
+def get_job_id():
+    job_id = "none"
+    if os.getenv("SLURM_JOB_ID") is not None:
+        job_id = os.getenv("SLURM_JOB_ID")
+    elif os.getenv("K8S_WORKSPACE_ID") is not None:
+        job_id = os.getenv("K8S_WORKSPACE_ID")
+
+    return job_id
+
+
+def get_job_name():
+    job_name = f"unknown-{now_time()}"
+    if os.getenv("JOB_NAME") is not None:
+        job_name = os.getenv("JOB_NAME")
+
+    return job_name
+
+
+def get_job_key():
+    return f"{get_job_id()}_{get_job_name()}"

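A small, illustrative combination of these helpers; the values are made up, SLURM_JOB_ID normally comes from the scheduler, and JOB_NAME is what MonitorManager.start_monitor() records.

from internlm.monitor.utils import get_job_key, set_env_var

set_env_var(key="SLURM_JOB_ID", value=12345)   # normally provided by slurm
set_env_var(key="JOB_NAME", value="demo_job")  # set by start_monitor() in practice
print(get_job_key())                           # -> "12345_demo_job"
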
@@ -28,6 +28,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.monitor import send_alert_message

 from .utils import compute_norm

@@ -542,6 +543,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
+                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None

@@ -34,18 +34,6 @@ def get_master_node():
     return result


-def get_process_rank():
-    proc_rank = -1
-    if os.getenv("SLURM_PROCID") is not None:
-        proc_rank = int(os.getenv("SLURM_PROCID"))
-    elif os.getenv("RANK") is not None:
-        # In k8s env, we use $RANK.
-        proc_rank = int(os.getenv("RANK"))
-
-    # assert proc_rank != -1, "get_process_rank cant't get right process rank!"
-    return proc_rank
-
-
 def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
     if torch.is_tensor(norm) and norm.device.type != "cuda":
         norm = norm.to(torch.cuda.current_device())

@@ -6,8 +6,8 @@ from tqdm import tqdm

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.metrics import AccPerplex
 from internlm.core.scheduler import SchedulerMetricHook
+from internlm.model.metrics import AccPerplex


 @contextmanager

@@ -90,15 +90,9 @@ def evaluate_on_val_dls(
             total_val_bsz = len(batch[1])
             assert total_val_bsz % data_cfg.micro_bsz == 0
             num_microbatches = total_val_bsz // data_cfg.micro_bsz
-            if gpc.config.model.sequence_parallel:
-                sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
-                tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
-                )
-            else:
-                tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                )
+            tensor_shape = torch.Size(
+                [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+            )

             with switch_evaluation_pipeline_scheduler(
                 trainer=trainer,

@@ -114,7 +108,6 @@
             assert total_val_bsz % data_cfg.micro_bsz == 0
             grad_accum_size = total_val_bsz // data_cfg.micro_bsz
             grad_accum_batch_size = data_cfg.micro_bsz
-            # import pdb; pdb.set_trace()
             with switch_evaluation_no_pipeline_scheduler(
                 trainer=trainer,
                 grad_accum_size=grad_accum_size,

@@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
         gpc.config.model.sequence_parallel = False
         yield
     finally:
         gpc.config.model.sequence_parallel = prev_mode

train.py (50 changed lines)

@@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
+from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer

@@ -37,7 +39,6 @@ from internlm.utils.common import (
     BatchSkipper,
     get_master_node,
     get_megatron_flops,
-    get_process_rank,
     launch_time,
     parse_args,
 )

@@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port


 def initialize_llm_logger(start_time: str):
+    """
+    Initialize customed uniscale logger.
+
+    Args:
+        start_time (str): The launch time of current training job.
+
+    Returns: The instance of uniscale logger.
+    """
+
     uniscale_logger = initialize_uniscale_logger(
         job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
     )

@@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):


 def get_validation_data_loader(num_worker: int = 0):
+    """Generate and return the validation data loader."""
+
     data_cfg = gpc.config.data

     if not data_cfg.valid_folder:

@@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
     Print some training metrics of current batch.
     """

+    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
+
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
         if is_no_pp_or_last_stage():

@@ -405,12 +419,11 @@
         else:
             logger.info(line)

+    # if loss spike occurs, send alert info to feishu
+    mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
+

 def main(args):
-    # initialize distributed environment
-    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
-    assert hasattr(gpc, "config") and gpc.config is not None
-
     # init setting
     skip_batches = gpc.config.data.skip_batches
     total_steps = gpc.config.data.total_steps

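To make the step-by-step spike rule above concrete: with the default loss_spike_limit of 1.5, an alert is sent as soon as the current step loss exceeds 1.5 times the previous step loss. The numbers below are purely illustrative.

# illustrative arithmetic for MonitorManager.monitor_loss_spike()
loss_spike_limit = 1.5
last_step_loss, cur_step_loss = 2.0, 3.2

if last_step_loss != -1 and cur_step_loss > loss_spike_limit * last_step_loss:
    print("loss spike alert")  # 3.2 > 1.5 * 2.0, so an alert would be sent
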
@@ -477,8 +490,8 @@ def main(args):
             model_load_path = load_model_only_folder
         else:
             logger.info(
-                f"===========New Run {current_time} on host:{socket.gethostname()},"
-                f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
+                f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
+                f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                 f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
             )

@@ -594,6 +607,9 @@
                 train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
                 if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                     logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                    send_alert_message(
+                        address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
+                    )

         # calculate and record the training metrics, eg. loss, accuracy and so on.
         record_current_batch_training_metrics(

@@ -646,9 +662,19 @@

 if __name__ == "__main__":
     args = parse_args()
+    hostname = socket.gethostname()

-    try:
-        main(args)
-    except Exception:
-        print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
-        traceback.print_exc()
+    # initialize distributed environment
+    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+
+    # initialize monitor manager context
+    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+        try:
+            main(args)
+        except Exception:
+            logger.error(
+                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
+                exc_info=traceback.format_exc(),
+            )
+            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())