mirror of https://github.com/InternLM/InternLM
Merge branch 'develop' into feature_add_moe
commit
9ad7942568
@@ -7,22 +7,29 @@ MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168

+MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
-# oss: 'boto3:s3://model_weights/XXX'
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"

+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
+CHECKPOINT_EVERY = 50
ckpt = dict(
-    # Path to save training ckpt.
-    save_ckpt_folder=SAVE_CKPT_FOLDER,
-    # Path to continue training ckpt (load model weights and scheduler/context states).
-    # load_ckpt_folder=LOAD_CKPT_FOLDER,
-    # Path to initialize with given model weights.
-    # load_model_only_folder=MODEL_ONLY_FOLDER,
-    checkpoint_every=50,
-    # Wheter to load optimizer states when continuing training.
-    load_optimizer=True,
+    enable_save_ckpt=False,  # enable ckpt save.
+    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
+    # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
+    # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
+    load_optimizer=True,  # Wheter to load optimizer states when continuing training.
+    checkpoint_every=CHECKPOINT_EVERY,
+    async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
+    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporarily files during asynchronous upload.
+    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage path.
+    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
)

TRAIN_FOLDER = "/path/to/dataset"
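For reference, the two checkpoint folder formats shown above are distinguished by their prefix: "local:" selects the file-system backend and "boto3:s3://" selects the S3 backend. A minimal, self-contained sketch of that split is given below; the helper name is hypothetical and the real resolution happens later in this commit, inside StorageManager._get_client in internlm/utils/storage_manager.py.

def split_ckpt_folder(folder: str):
    # Split "local:llm_ckpts" or "boto3:s3://model_weights/XXX" into (backend, path).
    backend, _, path = folder.partition(":")
    assert backend in ("local", "boto3"), f"unknown backend: {backend}"
    return backend, path

print(split_ckpt_folder("local:llm_ckpts"))               # ('local', 'llm_ckpts')
print(split_ckpt_folder("boto3:s3://model_weights/XXX"))  # ('boto3', 's3://model_weights/XXX')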
@@ -11,6 +11,7 @@ import torch
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
+from internlm.utils.storage_manager import init_storage_manager

logger = get_logger(__file__)


@@ -122,20 +123,44 @@ def args_sanity_check():
    if "load_model_only_folder" not in gpc.config.ckpt:
        gpc.config.ckpt._add_item("load_model_only_folder", None)

+    if "async_upload" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("async_upload", False)
+    else:
+        if gpc.config.ckpt.async_upload:
+            assert "save_ckpt_folder" in gpc.config.ckpt
+            if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
+                if gpc.is_rank_for_log():
+                    logger.warning(
+                        "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
+                    )
+                gpc.config.ckpt.async_upload = False
+            else:
+                if "async_upload_tmp_folder" not in gpc.config.ckpt:
+                    gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
+
+    if "snapshot_ckpt_folder" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
+
+    if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
+        gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
+        assert gpc.config.ckpt.oss_snapshot_freq > 0
+
    assert not (
        gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
    ), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."

-    gpc.config.ckpt._add_item(
-        "enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
-    )
+    if "enable_save_ckpt" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("enable_save_ckpt", False)

    if gpc.is_rank_for_log():
        logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
-        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
+        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
        logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
        logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")

+    # initialization storage manager
+    init_storage_manager(gpc.config.ckpt)
+
    # tensorboard writer config
    if "enable_tb" not in gpc.config:
        gpc.config._add_item("enable_tb", True)

@@ -202,7 +227,13 @@ def args_sanity_check():
    if "sequence_parallel" not in gpc.config.model:
        gpc.config.model._add_item("sequence_parallel", False)
    else:
-        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
+        assert not (
+            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+        ), "sequence parallel does not support use_flash_attn=False"
+
+    # feishu webhook address for alerting
+    if "alert_address" not in gpc.config:
+        gpc.config._add_item("alert_address", None)


def launch(
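The sanity check above mostly back-fills defaults for checkpoint keys that older configs do not set, and downgrades async_upload to a synchronous save when the target is not a boto3 path. A minimal sketch of that pattern, with a plain dict standing in for gpc.config.ckpt (the dict and values below are illustrative, not part of the commit):

ckpt_cfg = {"save_ckpt_folder": "local:llm_ckpts", "checkpoint_every": 50}

defaults = {
    "async_upload": False,
    "async_upload_tmp_folder": "/dev/shm/internlm_tmp_ckpt/",
    "enable_save_ckpt": False,
}
for key, value in defaults.items():
    ckpt_cfg.setdefault(key, value)

# Asynchronous upload only makes sense for a boto3 target.
if ckpt_cfg["async_upload"] and "boto3:" not in ckpt_cfg["save_ckpt_folder"]:
    ckpt_cfg["async_upload"] = False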
@@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

        output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

        if gpc.config.model.sequence_parallel:
            output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

        return output
@@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
        else:
            weight = self.weight
        return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
        )
@@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
        else:
            weight = self.weight
        return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
        )
@@ -170,7 +178,13 @@ class FeedForward(nn.Module):
            dtype=dtype,
        )
        self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.model.sequence_parallel,
+            device=device,
+            dtype=dtype,
        )
        self.w3 = RowParallelLinearTorch(
            hidden_features,
@@ -497,7 +497,7 @@ def build_model_with_cfg(
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
-    sequence_parallel: bool = False,
+    sequence_parallel: bool = False,  # pylint: disable=W0613
    num_experts: int = 1,
):
    """
@@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup

from internlm.core.context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)


def _split(input_, parallel_mode, dim=-1):
@@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim):
    return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    assert input.dtype == grad_output.dtype
    grad_weight = torch.matmul(grad_output.t(), input)
@@ -157,10 +161,11 @@ def fused_dense_func_torch(
    else:
        return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the corresponding chuck to the rank.

    Args:
        input_: input matrix.
        parallel_mode: parallel mode.
@@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.mode, ctx.dim), None, None


def split_forward_gather_backward(input_, parallel_mode, dim):
    return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
@@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
def try_import_RMSNorm():
    """
    Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm

    """
    try:
        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm

        return RMSNorm
-    except ModuleNotFoundError as e:
-        from internlm.utils.logger import get_logger
-
-        logger = get_logger(__file__)
+    except ModuleNotFoundError:
        logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
        from internlm.model.norm import RMSNormTorch as RMSNorm

        return RMSNorm
@@ -0,0 +1,4 @@ (new file: internlm/monitor/__init__.py)
from .monitor import initialize_monitor_manager, send_alert_message
from .utils import set_env_var

__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
@@ -0,0 +1,53 @@ (new file: internlm/monitor/alert.py)
import json
import time

import requests


def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
    """
    Use Feishu robot to send messages with the given webhook.

    Args:
        webhook (str): The webhook to be used to send message.
        title (str): The message title.
        message (str): The message body.

    Returns:
        The response from the request. Or catch the exception and return None.

    Raises:
        Exception: An exception rasied by the HTTP post request.

    """

    headers = {"Content-Type": "application/json;charset=utf-8"}
    msg_body = {
        "timestamp": int(time.time()),
        "msg_type": "post",
        "content": {
            "post": {
                "zh_cn": {
                    "title": title,
                    "content": [
                        [
                            {
                                "tag": "text",
                                "text": message,
                            },
                        ],
                    ],
                },
            },
        },
    }

    try:
        res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
        res = res.json()
        print(f"Feishu webhook response: {res}")
    except Exception as err:  # pylint: disable=W0703
        print(f"HTTP Post error: {err}")
        res = None

    return res
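A hedged usage sketch for the helper above; the webhook URL is a placeholder and must be replaced with a real Feishu bot webhook. The function returns the parsed JSON response, or None if the HTTP post failed.

from internlm.monitor.alert import send_feishu_msg_with_webhook

resp = send_feishu_msg_with_webhook(
    webhook="https://open.feishu.cn/open-apis/bot/v2/hook/<your-token>",  # placeholder
    title="internlm-train",
    message="Loss spike detected at step 1200.",
)
print(resp)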
@@ -0,0 +1,226 @@ (new file: internlm/monitor/monitor.py)
import os
import signal
import socket
import time
from contextlib import contextmanager
from threading import Thread

from internlm.core.context import global_context as gpc
from internlm.monitor.alert import send_feishu_msg_with_webhook
from internlm.utils.common import SingletonMeta

from .utils import get_job_key, set_env_var


def send_alert_message(address: str = None, title: str = None, message: str = None):
    """
    Send alert messages to the given Feishu webhook address in log rank.

    Args:
        address (str): The alert address to be used to send message, defaults to None.
        title (str): The message title, defaults to None.
        message (str): The message body, defaults to None.
    """

    if address is not None and gpc.is_rank_for_log():
        send_feishu_msg_with_webhook(
            webhook=address,
            title=title if title else get_job_key(),
            message=message,
        )


class MonitorTracker(Thread):
    """
    Track job status and alert to Feishu during job training.

    Args:
        alert_address (str): The Feishu webhook address for sending alerting messages.
        check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
        loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
    """

    def __init__(
        self,
        alert_address: str,
        check_interval: float = 300,
        loss_spike_limit: float = 1.5,
    ):
        super().__init__()
        self.alert_address = alert_address
        self.check_interval = check_interval
        self.loss_spike_limit = loss_spike_limit
        self.last_active_time = -1
        self.last_loss_value = -1
        self.stopped = False
        self.start()

    def run(self):
        """
        start the monitor tracker.
        """

        while not self.stopped:
            try:
                self._check_stuck()
                self._check_loss_spike()
            except Exception:
                continue
            time.sleep(self.check_interval)

    def _check_stuck(self):
        """
        Check training status for potential stuck condition.
        """

        new_active_time = -1
        if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
            new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
        if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
            self._send_alert("Training may be in stuck status, please check it.")
        self.last_active_time = new_active_time

    def _check_loss_spike(self):
        """
        Check for loss value spikes.
        """

        if gpc.is_rank_for_log():
            new_loss_value = -1
            new_step_id = -1
            if os.getenv("LOSS") is not None:
                new_loss_value = os.getenv("LOSS")
            if os.getenv("STEP_ID") is not None:
                new_step_id = os.getenv("STEP_ID")

            if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
                assert int(new_step_id) >= 0
                self._send_alert(
                    f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
                    f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
                )

            self.last_loss_value = new_loss_value

    def _send_alert(self, message):
        """
        Send alerting message to the Feishu webhook address.

        Args:
            message (str): The alerting message to be sent.
        """

        send_alert_message(
            address=self.alert_address,
            message=message,
        )

    def stop(self):
        """
        Stop the monitor tracker.
        """

        self.stopped = True


class MonitorManager(metaclass=SingletonMeta):
    """
    Monitor Manager for managing monitor thread and monitoring training status.
    """

    def __init__(self, loss_spike_limit: float = 1.5) -> None:
        self.monitor_thread = None
        self.loss_spike_limit = loss_spike_limit
        self.last_step_loss = -1

    def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
        """Check loss value, if loss spike occurs, send alert message to Feishu."""
        set_env_var(key="LOSS", value=cur_step_loss)
        set_env_var(key="STEP_ID", value=step_count)

        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
            send_alert_message(
                address=alert_address,
                message=(
                    f"Checking step by step: Loss spike may be happened in step {step_count}, "
                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
                ),
            )
        self.last_step_loss = cur_step_loss

    def monitor_exception(self, alert_address: str = None, excp_info: str = None):
        """Catch and format exception information, send alert message to Feishu."""
        filtered_trace = excp_info.split("\n")[-10:]
        format_trace = ""
        for line in filtered_trace:
            format_trace += "\n" + line
        send_alert_message(
            address=alert_address,
            message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
        )

    def handle_sigterm(self, alert_address: str = None):
        """Catch SIGTERM signal, and send alert message to Feishu."""

        def sigterm_handler(sys_signal, frame):
            print("receive frame: ", frame)
            print("receive signal: ", sys_signal)
            send_alert_message(
                address=alert_address,
                message=f"Process received signal {signal} and exited.",
            )

        signal.signal(signal.SIGTERM, sigterm_handler)

    def start_monitor(
        self,
        job_name: str,
        alert_address: str,
        monitor_interval_seconds: int = 300,
        loss_spike_limit: float = 1.5,
    ):
        """
        Initialize and start monitor thread for checking training job status, loss spike and so on.

        Args:
            job_name (str): The training job name.
            alert_address (str): The Feishu webhook address for sending alert messages.
            monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
            loss_spike_limit (float): The limit multiple of current loss to previous loss value, which means loss spike
                may be occurs, defaults to 1.5.
        """

        # initialize some variables for monitoring
        set_env_var(key="JOB_NAME", value=job_name)

        # start a monitor thread, periodically check the training status
        self.monitor_thread = MonitorTracker(
            alert_address=alert_address,
            check_interval=monitor_interval_seconds,
            loss_spike_limit=loss_spike_limit,
        )

    def stop_monitor(self):
        """Stop the monitor and alert thread."""
        if self.monitor_thread is not None:
            self.monitor_thread.stop()


monitor_manager = MonitorManager()


@contextmanager
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
    if alert_address is not None:
        try:
            monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
            monitor_manager.handle_sigterm(alert_address=alert_address)
            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
            yield
        finally:
            send_alert_message(
                address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
            )
            monitor_manager.stop_monitor()
    else:
        yield
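A sketch of how this monitor is meant to be driven from a training script (this is the pattern the train.py changes later in the commit adopt); the job name and loss values below are placeholders, and with alert_address=None no message is actually sent.

from internlm.monitor import initialize_monitor_manager
from internlm.monitor.monitor import monitor_manager as mm

with initialize_monitor_manager(job_name="7b_train", alert_address=None):
    for step, loss in enumerate([2.1, 2.0, 4.5]):  # stand-in for the real training loop
        mm.monitor_loss_spike(alert_address=None, step_count=step, cur_step_loss=loss)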
@@ -0,0 +1,32 @@ (new file: internlm/monitor/utils.py)
import os
from datetime import datetime


def now_time():
    return datetime.now().strftime("%b%d_%H-%M-%S")


def set_env_var(key, value):
    os.environ[str(key)] = str(value)


def get_job_id():
    job_id = "none"
    if os.getenv("SLURM_JOB_ID") is not None:
        job_id = os.getenv("SLURM_JOB_ID")
    elif os.getenv("K8S_WORKSPACE_ID") is not None:
        job_id = os.getenv("K8S_WORKSPACE_ID")

    return job_id


def get_job_name():
    job_name = f"unknown-{now_time()}"
    if os.getenv("JOB_NAME") is not None:
        job_name = os.getenv("JOB_NAME")

    return job_name


def get_job_key():
    return f"{get_job_id()}_{get_job_name()}"
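For illustration, the job key built by these helpers combines the scheduler job id (SLURM_JOB_ID or K8S_WORKSPACE_ID, falling back to "none") with the JOB_NAME environment variable; the name used below is a placeholder.

from internlm.monitor.utils import get_job_key, set_env_var

set_env_var("JOB_NAME", "7b_moe_sft")
print(get_job_key())  # e.g. "none_7b_moe_sft" when run outside of SLURM/K8S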
@@ -29,6 +29,7 @@ from internlm.solver.optimizer.utils import (
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.monitor import send_alert_message

from .utils import compute_norm
@@ -543,6 +544,7 @@ class HybridZeroOptimizer(BaseOptimizer):
        if found_inf:
            if gpc.is_rank_for_log():
                logger.warning("Overflow occurs, please check it.")
+                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
            return False, None
@@ -34,18 +34,6 @@ def get_master_node():
    return result


-def get_process_rank():
-    proc_rank = -1
-    if os.getenv("SLURM_PROCID") is not None:
-        proc_rank = int(os.getenv("SLURM_PROCID"))
-    elif os.getenv("RANK") is not None:
-        # In k8s env, we use $RANK.
-        proc_rank = int(os.getenv("RANK"))
-
-    # assert proc_rank != -1, "get_process_rank cant't get right process rank!"
-    return proc_rank
-
-
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    if torch.is_tensor(norm) and norm.device.type != "cuda":
        norm = norm.to(torch.cuda.current_device())
@@ -6,8 +6,8 @@ from tqdm import tqdm

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.metrics import AccPerplex
from internlm.core.scheduler import SchedulerMetricHook
+from internlm.model.metrics import AccPerplex


@contextmanager
@@ -90,15 +90,9 @@ def evaluate_on_val_dls(
                    total_val_bsz = len(batch[1])
                    assert total_val_bsz % data_cfg.micro_bsz == 0
                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    if gpc.config.model.sequence_parallel:
-                        sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
-                        tensor_shape = torch.Size(
-                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
-                        )
-                    else:
-                        tensor_shape = torch.Size(
-                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                        )
+                    tensor_shape = torch.Size(
+                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                    )

                    with switch_evaluation_pipeline_scheduler(
                        trainer=trainer,
@@ -114,7 +108,6 @@ def evaluate_on_val_dls(
                    assert total_val_bsz % data_cfg.micro_bsz == 0
                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                    grad_accum_batch_size = data_cfg.micro_bsz
-                    # import pdb; pdb.set_trace()
                    with switch_evaluation_no_pipeline_scheduler(
                        trainer=trainer,
                        grad_accum_size=grad_accum_size,
@@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
        gpc.config.model.sequence_parallel = False
        yield
    finally:
        gpc.config.model.sequence_parallel = prev_mode
@@ -4,6 +4,7 @@
import copy
import os
import time
+from enum import Enum
from typing import Dict

import torch
@@ -15,10 +16,22 @@ from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.storage_manager import get_fns, llm_load, llm_save
+from internlm.utils.storage_manager import (
+    get_fns,
+    get_storage_manager,
+    llm_load,
+    llm_save,
+)

logger = get_logger(__file__)

+quit_signal_handler = None
+
+
+class CheckpointType(Enum):
+    NORMAL_CHECKPOINT = 1
+    SNAPSHOT_CHECKPOINT = 2
+
+
def get_model_topology(model):
    """
@@ -289,3 +302,77 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train

    if gpc.is_rank_for_log():
        logger.info(f"reload load_scheduler:{lr_scheduler}")
+
+
+class CheckpointSaveManager:
+    """StorageManagerContext"""
+
+    def __init__(
+        self,
+        ckpt_config,
+        model,
+        optimizer,
+        lr_scheduler,
+        model_config,
+    ) -> None:
+        """
+        CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous
+        upload mode, you must call wait_async_upload_finish at the end of the program to wait
+        for the asynchronous ckpt upload to complete.
+
+        Args:
+            ckpt_config (dict): model checkpoint config.
+            model (nn.module): model obj
+            optimizer (object): optimzier obj.
+            lr_scheduler (object): lr_scheduler obj.
+            model_config (dict): model config.
+        """
+        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
+        self.checkpoint_every = ckpt_config.checkpoint_every
+        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
+        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
+        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
+        self.storage_manager = get_storage_manager()
+        self.snapshot_counter = 0
+
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.model_config = model_config
+
+    def try_save_checkpoint(self, train_state):
+        if not self.enable_save_ckpt:
+            return
+
+        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
+        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
+        if train_state.step_count % self.checkpoint_every == 0:
+            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
+        if save_ckpts is False:
+            if quit_signal_handler is not None:
+                save_ckpts, save_type = quit_signal_handler(train_state)
+
+        if save_ckpts:
+            # Wait for the previous round of asynchronous upload storage to complete.
+            self.storage_manager.wait()
+            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
+                # Snapshot number, with only two snapshots written alternately.
+                self.snapshot_counter = (self.snapshot_counter + 1) % 2
+                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
+            else:
+                save_ckpt_folder = self.save_ckpt_folder
+
+            save_checkpoint(
+                folder=save_ckpt_folder,
+                model=self.model,
+                optimizer=self.optimizer,
+                scheduler=self.lr_scheduler,
+                train_state=train_state,
+                model_config=self.model_config,
+            )
+
+    def wait_async_upload_finish(self):
+        """wait for all checkpoint uploads to be completed"""
+        self.storage_manager.wait()
+        torch.distributed.barrier()
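A small, runnable sketch of the save-decision logic in try_save_checkpoint above: snapshots are written every oss_snapshot_freq steps into two alternating slots, and a full checkpoint every checkpoint_every steps (a full checkpoint takes precedence when both fall on the same step). The constants are example values, not taken from the commit.

CHECKPOINT_EVERY = 50
OSS_SNAPSHOT_FREQ = CHECKPOINT_EVERY // 2  # 25

snapshot_counter = 0
for step in range(1, 101):
    save, kind = False, "normal"
    if OSS_SNAPSHOT_FREQ > 1 and step % OSS_SNAPSHOT_FREQ == 0:
        save, kind = True, "snapshot"
    if step % CHECKPOINT_EVERY == 0:
        save, kind = True, "normal"
    if save:
        if kind == "snapshot":
            snapshot_counter = (snapshot_counter + 1) % 2
            print(f"step {step}: snapshot ckpt -> snapshot/{snapshot_counter}")
        else:
            print(f"step {step}: full ckpt -> save_ckpt_folder/{step}")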
@@ -1,18 +1,26 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

+import asyncio
+import concurrent.futures
import hashlib
import io
import os
+import pickle
import re
import socket
-from enum import Enum
-from typing import Any, Dict, List, Union
+import stat
+from asyncio import InvalidStateError
+from asyncio.tasks import ALL_COMPLETED
+from datetime import datetime
+from typing import Any, Awaitable, Callable, Dict, List, Union

import boto3
import botocore
import torch
+import torch.distributed as dist

+from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
@@ -41,10 +49,6 @@ def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
    storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)


-class CheckpointType(Enum):
-    NORMAL_CHECKPOINT = 1
-
-
class StorageClient:
    """
    StorageClient as a client for s3 storage access.
@@ -54,7 +58,7 @@ class StorageClient:
        self.handler = handler

    @staticmethod
-    def load(client, load_path: str, map_location):
+    def load(client, load_path: str, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
@@ -71,25 +75,51 @@ class StorageClient:


class Boto3MetaInfo:
-    def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
-        self.client = client
+    """Boto3 meta info for save/load etc."""
+
+    def __init__(
+        self,
+        is_async,
+        handler: StorageClient,
+        bucket_name: str,
+        endpoint: str,
+        file_path: str,
+        async_upload_fn: callable,
+        local_nvme_path=None,
+    ) -> None:
+        self.is_async = is_async
+        self.client = handler
        self.bucket_name = bucket_name
        self.endpoint = endpoint
        self.file_path = file_path
+        self.async_upload_fn = async_upload_fn
+        self.local_nvme_path = local_nvme_path
+
+    def __str__(self) -> str:
+        return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
+local_nvme_path: {self.local_nvme_path}"


class LocalMetaInfo:
-    def __init__(self, client: StorageClient, dest_path: str) -> None:
-        self.client = client
+    """Local meta info for save/load etc."""
+
+    def __init__(self, handler: StorageClient, dest_path: str) -> None:
+        self.is_async = False
+        self.client = handler
        self.dest_path = dest_path
+        self.async_upload_fn = None


def unpack_meta(meta):
    args = []
+    is_async = meta.is_async
    for k, v in meta.__dict__.items():
-        if k == "endpoint":
+        if k in ("endpoint", "async_upload_fn", "is_async"):
+            continue
+        if not is_async and k in ("local_nvme_path",):
            continue
        args.append(v)

    return args
@@ -101,21 +131,6 @@ def compute_file_md5_by_chunk(file_name: str):
    return hash_md5.hexdigest()


-def get_boto3_meta(fp: str) -> Boto3MetaInfo:
-    assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
-    parts = fp.lstrip("s3://").split(os.path.sep)
-    match = boto3_url_re.match(parts[0])
-    assert match is not None, f"url '{fp}' is not a valid boto3 url"
-    bucket_name, endpoint = match.group(1), match.group(2)
-    endpoint = "http://" + endpoint + ":80"
-    return Boto3MetaInfo(None, bucket_name, endpoint, os.path.sep.join(parts[1:]))
-
-
-def get_local_meta(fp: str) -> LocalMetaInfo:
-    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
-    return LocalMetaInfo(None, fp)
-
-
class Boto3Client(StorageClient):
    """
    Boto3Client
@@ -169,7 +184,9 @@ class Boto3Client(StorageClient):
        )

    @staticmethod
-    def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
+    def sync_upload_fileobj(
+        handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
+    ):  # pylint: disable=W0613
        assert saved_obj is not None, "saved_obj is None!"
        try:
            with io.BytesIO() as f:
@@ -182,7 +199,14 @@ class Boto3Client(StorageClient):
        ) from exc

    @staticmethod
-    def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
+    def load(
+        handler,
+        bucket_name: str,
+        fp: str,
+        local_nvme_path: str,  # pylint: disable=W0613
+        *args,
+        **kwargs,
+    ) -> Dict:
        """
        Args:
            fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@@ -191,7 +215,7 @@ class Boto3Client(StorageClient):
            with io.BytesIO() as f:
                handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                f.seek(0)
-                states = torch.load(f, *args, map_location=map_location, **kwargs)
+                states = torch.load(f, *args, **kwargs)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@@ -199,15 +223,11 @@ class Boto3Client(StorageClient):
        return states

    @staticmethod
-    def assert_fp_exists(
-        handler,
-        bucket_name: str,
-        fp: str,
-    ):
+    def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
        assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp

    @staticmethod
-    def get_fns(handler, bucket_name: str, fp: str):
+    def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs):  # pylint: disable=W0613
        """
        Ref: https://stackoverflow.com/questions/54314563/
        how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
@@ -222,6 +242,22 @@ class Boto3Client(StorageClient):
            folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
        return folder_name_list

+    @staticmethod
+    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
+        try:
+            with open(local_nvme_path, "rb") as f:
+                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
+        except handler.botocore.exceptions.EndpointConnectionError as exc:
+            raise RuntimeError(
+                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
+            ) from exc
+        except Exception as e:
+            raise e
+
+    @staticmethod
+    def delete_obj(handler, fp: str):
+        raise NotImplementedError("boto3 not support delete_obj")
+

class LocalClient(StorageClient):
    """
@@ -241,11 +277,11 @@ class LocalClient(StorageClient):
        torch.save(saved_obj, fp, *args, **kwargs)

    @staticmethod
-    def load(handler, fp: str, *args, map_location="cpu", **kwargs):
+    def load(handler, fp: str, *args, **kwargs):  # pylint: disable=W0613
        assert isinstance(handler, LocalClient)
        assert os.path.exists(fp), f"{fp} is not found!"
        with open(fp, "rb") as f:
-            states = torch.load(f, map_location=map_location, *args, **kwargs)
+            states = torch.load(f, *args, **kwargs)
        return states

    @staticmethod
@@ -267,9 +303,77 @@ class LocalClient(StorageClient):
        os.remove(fp)


+def get_tmp_file_name(tmp_local_folder: str, fp: str):
+    """
+    It should be noted that all our temporary files will be stored in the same folder,
+    so the file name passed upstream must be unique.
+    """
+    base_path = os.path.join(tmp_local_folder, fp.split("/")[-1])
+    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+    pid = os.getpid()
+    # step = self.step_counter
+    return "-".join([base_path, current_time, str(pid)]) + ".tmpfile"  # , str(step)
+
+
+def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
+    assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
+    parts = fp.lstrip("s3://").split(os.path.sep)
+    match = boto3_url_re.match(parts[0])
+    assert match is not None, f"url '{fp}' is not a valid boto3 url"
+    bucket_name, endpoint = match.group(1), match.group(2)
+    endpoint = "http://" + endpoint + ":80"
+    tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
+    return Boto3MetaInfo(
+        is_async=is_async,
+        handler=None,
+        bucket_name=bucket_name,
+        endpoint=endpoint,
+        file_path=os.path.sep.join(parts[1:]),
+        async_upload_fn=Boto3Client.async_upload_fileobj,
+        local_nvme_path=tmp_step_file,
+    )
+
+
+def get_local_meta(fp: str) -> LocalMetaInfo:
+    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
+    return LocalMetaInfo(None, fp)
+
+
+def get_mount_point_free_size(path: str):
+    """
+    Returns the remaining space of the temporary storage mount point as a percentage.
+    Args:
+        path (str): temporary storage folder path.
+
+    Raises:
+        FileNotFoundError: If the temporary storage folder does not exist,
+        an error will be reported。
+    """
+    if os.path.exists(path):
+        st = os.statvfs(path)
+        # f_bavail: Number of free blocks for unprivileged users.
+        # f_bsize: Filesystem block size.
+        # return unit is TB.
+        return st.f_bavail * st.f_bsize / (1024**3)
+
+
+def check_tmp_folder_accessibility(tmp_local_folder: str):
+    """
+    Check access permissions for temporary storage.
+    """
+    ret = True
+    if os.path.exists(tmp_local_folder):
+        ret &= os.access(tmp_local_folder, os.W_OK)
+        ret &= os.access(tmp_local_folder, os.R_OK)
+        if ret is False:
+            error_str = f'{socket.gethostname()} dose not have read and write permissions on {tmp_local_folder}"'
+            raise RuntimeError(error_str)
+
+
class StorageManager(metaclass=SingletonMeta):
    """
    Storage Manager for saving or loading checkpoint.
+    TODO: add a thread to poll the asynchronous storage state.
    """

    BACKEND_TYPE = {"boto3", "local"}
@@ -279,8 +383,39 @@ class StorageManager(metaclass=SingletonMeta):
    }
    CLI_DICT = {}

-    def __init__(self) -> None:
-        pass
+    def __init__(self, enable_save, tmp_local_folde="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
+        self._exception_list = []
+        self._to_be_del_files = []
+        self._async_stack = []
+        self.upload_count = 0
+        self.tmp_local_folder = tmp_local_folde
+        self.async_mode = async_mode
+        self.has_warning = False
+
+        if enable_save and self.async_mode:
+            self._async_loop = asyncio.new_event_loop()
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
+
+            check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
+
+            # Try to create tmp folder
+            try:
+                os.makedirs(self.tmp_local_folder, exist_ok=True)
+                os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
+            except FileExistsError:
+                pass
+
+            # In case it is a directory created by other users, we check the permissions again.
+            check_tmp_folder_accessibility(self.tmp_local_folder)
+
+            # Try to clean tmp folder's empty folder.
+            self.try_delete_tmpfile(self.tmp_local_folder)
+
+            # Avaliable storeage space check.
+            free_size = get_mount_point_free_size(self.tmp_local_folder)
+            if free_size < 0.1:
+                logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
+                raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")

    def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
        """
@@ -301,7 +436,7 @@ class StorageManager(metaclass=SingletonMeta):
            meta_info = get_local_meta(path)
            backend_key = backend
        elif backend == "boto3":
-            meta_info = get_boto3_meta(path)
+            meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
            backend_key = backend + ":" + meta_info.endpoint
            init_args = (meta_info.endpoint,)
        if (
@@ -310,10 +445,12 @@ class StorageManager(metaclass=SingletonMeta):
            or "HTTP_PROXY" in os.environ
            or "HTTPS_PROXY" in os.environ
        ):
-            raise RuntimeWarning(
-                "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
-the proxy may make boto3 unavailable or affect performance."
-            )
+            if not self.has_warning:
+                logger.warning(
+                    "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
+the proxy may make boto3 unavailable or affect performance."
+                )
+                self.has_warning = True

        assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
@ -333,19 +470,137 @@ the proxy may make boto3 unavailable or affect performance."
|
||||||
meta = self._get_client(path=folder)
|
meta = self._get_client(path=folder)
|
||||||
return meta.client.get_fns(*unpack_meta(meta))
|
return meta.client.get_fns(*unpack_meta(meta))
|
||||||
|
|
||||||
def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
|
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
|
||||||
meta = self._get_client(path=save_path)
|
meta = self._get_client(path=save_path)
|
||||||
|
|
||||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
if async_upload is None:
|
||||||
|
async_upload = self.async_mode
|
||||||
def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
|
if async_upload:
|
||||||
|
assert (
|
||||||
|
self.tmp_local_folder
|
||||||
|
), "StorageManager is not setted tmp_local_folder, so async save cannot be performed."
|
||||||
|
tmp_step_file = meta.local_nvme_path
|
||||||
|
self._to_be_del_files.append(tmp_step_file)
|
||||||
|
with open(tmp_step_file, "wb") as f:
|
||||||
|
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
|
||||||
|
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||||
|
else:
|
||||||
|
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||||
|
self.upload_count += 1
|
||||||
|
|
||||||
|
def load(self, load_path: str, *args, **kwargs) -> Any:
|
||||||
|
self.wait()
|
||||||
meta = self._get_client(path=load_path)
|
meta = self._get_client(path=load_path)
|
||||||
return meta.client.load(*unpack_meta(meta), map_location=map_location, *args, **kwargs)
|
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
|
||||||
|
|
||||||
def delete_obj(self, fp: str):
|
def delete_obj(self, fp: str):
|
||||||
meta = self._get_client(path=fp)
|
meta = self._get_client(path=fp)
|
||||||
meta.client.delete_obj(*unpack_meta(meta))
|
meta.client.delete_obj(*unpack_meta(meta))
|
||||||
|
|
||||||
|
def _del_tmp_folder(self):
|
||||||
|
for fp in self._to_be_del_files:
|
||||||
|
try:
|
||||||
|
os.remove(fp)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except SystemError as e:
|
||||||
|
logger.error(f'delete file: {fp}, failed for reason:"{e}"')
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
storage_manager = StorageManager()
|
def try_delete_tmpfile(self, tmp_dir: str):
|
||||||
|
"""Delete temporary files in tmp_dir."""
|
||||||
|
|
||||||
|
for filename in os.listdir(tmp_dir):
|
||||||
|
if filename.endswith(".tmpfile"):
|
||||||
|
file_path = os.path.join(tmp_dir, filename)
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
logger.info(f"Delete tmpfile: {file_path}")
|
||||||
|
except OSError:
|
||||||
|
# Ignore deletion errors
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _sync_tasks(self) -> Awaitable[None]:
|
||||||
|
|
||||||
|
if not self._async_stack:
|
||||||
|
return
|
||||||
|
|
||||||
|
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
|
||||||
|
|
||||||
|
for task in self._async_stack:
|
||||||
|
try:
|
||||||
|
task.exception()
|
||||||
|
except InvalidStateError:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
file_id = len(self._exception_list)
|
||||||
|
self._exception_list.append((e, file_id))
|
||||||
|
|
||||||
|
logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}")
|
||||||
|
|
||||||
|
self._async_stack.clear()
|
||||||
|
|
||||||
|
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Overview:
|
||||||
|
Execute task in background, then apppend the future instance in _async_stack.
|
||||||
|
Arguments:
|
||||||
|
- fn (:obj:`Callable`): Synchronization fuction.
|
||||||
|
"""
|
||||||
|
if not self._async_loop:
|
||||||
|
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
|
||||||
|
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
|
||||||
|
self._async_stack.append(t)
|
||||||
|
|
||||||
|
    def wait(self) -> bool:
        """Wait for async operations to complete."""
        if not self.async_mode:
            return

        if self._async_loop:
            self._async_loop.run_until_complete(self._sync_tasks())

        if self._exception_list:
            # tuples are appended as (exception, file_id) in _sync_tasks
            for error_msg, file_id in self._exception_list:
                logger.error(
                    f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
                    f"failed on step {self.upload_count}: {error_msg}"
                )

                # TODO: Re-upload in sync mode
                raise RuntimeError(
                    f"Failed to upload {self._to_be_del_files[file_id]} on step {self.upload_count}: {error_msg}"
                )

        self._del_tmp_folder()
        self._exception_list.clear()
        self._to_be_del_files.clear()

        if gpc.is_rank_for_log():
            logger.info("all async uploads succeeded!")
        self.upload_count += 1
storage_manager: StorageManager = None


def init_storage_manager(ckpt_config):
    global storage_manager
    storage_manager = StorageManager(
        ckpt_config.enable_save_ckpt,
        tmp_local_folder=ckpt_config.async_upload_tmp_folder,
        async_mode=ckpt_config.async_upload,
    )


def get_storage_manager():
    assert storage_manager is not None, "storage_manager has not been init!"
    return storage_manager


def wait_async_upload_finish():
    dist.barrier()
    storage_manager.wait()
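For orientation, a minimal sketch of how these module-level helpers fit together. It is illustrative only: the config object is a SimpleNamespace stand-in for the fields init_storage_manager reads (enable_save_ckpt, async_upload_tmp_folder, async_upload), and actually running it requires the distributed environment and, for boto3 paths, object-store credentials.

```python
# Illustrative wiring of the storage manager defined above (not part of the diff).
from types import SimpleNamespace

from internlm.utils.storage_manager import (
    get_storage_manager,
    init_storage_manager,
)

# Placeholder config values; the real run takes these from gpc.config.ckpt.
ckpt_config = SimpleNamespace(
    enable_save_ckpt=True,
    async_upload=True,
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",
)

init_storage_manager(ckpt_config)      # creates the global StorageManager instance
manager = get_storage_manager()        # later callers retrieve the same instance

# ... checkpoints are written through manager.save(...) during training ...

manager.wait()                         # block until pending async uploads have finished
```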
train.py (80 changes)
@@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
+from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
@@ -37,7 +39,6 @@ from internlm.utils.common import (
     BatchSkipper,
     get_master_node,
     get_megatron_flops,
-    get_process_rank,
     launch_time,
     parse_args,
 )
@@ -45,12 +46,12 @@ from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_paral
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
+    CheckpointSaveManager,
     load_context,
     load_model_checkpoint,
     load_optimizer_checkpoint,
     load_sampler,
     load_scheduler,
-    save_checkpoint,
 )
 from internlm.utils.parallel import (
     get_parallel_log_file_name,
@@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
 
 
 def initialize_llm_logger(start_time: str):
+    """
+    Initialize customized uniscale logger.
+
+    Args:
+        start_time (str): The launch time of the current training job.
+
+    Returns: The instance of uniscale logger.
+    """
+
     uniscale_logger = initialize_uniscale_logger(
         job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
     )
@@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
 
 
 def get_validation_data_loader(num_worker: int = 0):
+    """Generate and return the validation data loader."""
+
     data_cfg = gpc.config.data
 
     if not data_cfg.valid_folder:
@@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
     Print some training metrics of current batch.
     """
+    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
+
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
         if is_no_pp_or_last_stage():
@@ -405,12 +419,11 @@ def record_current_batch_training_metrics(
         else:
             logger.info(line)
 
+    # if loss spike occurs, send alert info to feishu
+    mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
+
 
 def main(args):
-    # initialize distributed environment
-    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
-    assert hasattr(gpc, "config") and gpc.config is not None
-
     # init setting
     skip_batches = gpc.config.data.skip_batches
     total_steps = gpc.config.data.total_steps
@@ -419,11 +432,6 @@ def main(args):
     label_smoothing = gpc.config.loss.label_smoothing
     lr = gpc.config.adam.lr
 
-    # ckpt setting
-    save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder
-    enable_save_ckpt = gpc.config.ckpt.enable_ckpt
-    checkpoint_every = gpc.config.ckpt.checkpoint_every
-
     load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
     load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
@@ -477,8 +485,8 @@ def main(args):
         model_load_path = load_model_only_folder
     else:
         logger.info(
-            f"===========New Run {current_time} on host:{socket.gethostname()},"
-            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
+            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
+            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
             f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
         )
@@ -514,6 +522,14 @@ def main(args):
         if load_optimizer:
             load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
 
+    ckpt_save_manager = CheckpointSaveManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model_config=gpc.config.model,
+    )
+
     # initialize metric for calculating accuracy and perplexity
     metric = AccPerplex(
         device=torch.cuda.current_device(),
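The manager introduced in this hunk bundles the model, optimizer, scheduler and ckpt config so the training loop no longer threads individual save_checkpoint arguments around. As a rough illustration of the call surface train.py relies on after this change, here is a self-contained stand-in; the class below is hypothetical (the real CheckpointSaveManager comes from internlm.utils.model_checkpoint and its implementation is not shown in this diff).

```python
# Hypothetical stand-in mirroring only the interface used in train.py:
# constructor kwargs, try_save_checkpoint, wait_async_upload_finish.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class TrainStateStub:
    step_count: int = 0


class CheckpointSaveManagerSketch:
    def __init__(self, ckpt_config: Dict[str, Any], model=None, optimizer=None,
                 lr_scheduler=None, model_config=None):
        self.ckpt_config = ckpt_config
        self.pending_uploads: List[int] = []

    def try_save_checkpoint(self, train_state: TrainStateStub) -> None:
        # The manager decides internally whether this step needs a checkpoint,
        # instead of main() testing `step_count % checkpoint_every == 0` itself.
        every = self.ckpt_config.get("checkpoint_every", 50)
        if self.ckpt_config.get("enable_save_ckpt") and train_state.step_count % every == 0:
            self.pending_uploads.append(train_state.step_count)  # pretend to save/upload

    def wait_async_upload_finish(self) -> None:
        # Stand-in for draining asynchronous uploads at the end of training.
        self.pending_uploads.clear()


# Usage mirroring the new train.py flow:
manager = CheckpointSaveManagerSketch(ckpt_config={"enable_save_ckpt": True, "checkpoint_every": 2})
state = TrainStateStub()
for _ in range(4):
    state.step_count += 1
    manager.try_save_checkpoint(state)
manager.wait_async_upload_finish()
```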
@@ -594,6 +610,9 @@ def main(args):
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
             if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                send_alert_message(
+                    address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
+                )
 
         # calculate and record the training metrics, eg. loss, accuracy and so on.
         record_current_batch_training_metrics(
@@ -629,26 +648,27 @@ def main(args):
         )
 
         # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
-        # save batch sampler that tracks the true consumed samples
-        if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
-            save_checkpoint(
-                folder=save_ckpt_folder,
-                model=model,
-                optimizer=optimizer,
-                scheduler=lr_scheduler,
-                train_state=train_state,
-                model_config=gpc.config.model,
-            )
+        # # save batch sampler that tracks the true consumed samples
+        ckpt_save_manager.try_save_checkpoint(train_state)
 
-    dist.barrier()
+    # wait for all checkpoint uploads to be completed
+    ckpt_save_manager.wait_async_upload_finish()
 
 
 if __name__ == "__main__":
     args = parse_args()
+    hostname = socket.gethostname()
 
-    try:
-        main(args)
-    except Exception:
-        print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
-        traceback.print_exc()
+    # initialize distributed environment
+    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+
+    # initialize monitor manager context
+    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+        try:
+            main(args)
+        except Exception:
+            logger.error(
+                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
+                exc_info=traceback.format_exc(),
+            )
+            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
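The reworked entry point wraps main() in a monitor context so an unhandled exception is logged with its rank and forwarded as an alert rather than just printed. A generic, self-contained sketch of that shape follows, with monitor_context and send_alert as simplified stand-ins (not InternLM's initialize_monitor_manager or feishu alerting):

```python
# Generic sketch of the "monitor context around main()" shape used above.
# monitor_context and send_alert are simplified stand-ins for illustration only.
import socket
import traceback
from contextlib import contextmanager


def send_alert(message: str) -> None:
    print("ALERT:", message[:200])  # stand-in for a webhook/feishu call


@contextmanager
def monitor_context(job_name: str):
    # A real monitor manager would start background monitoring threads here.
    try:
        yield
    finally:
        pass  # ...and stop them here


def main() -> None:
    raise RuntimeError("boom")  # placeholder workload


if __name__ == "__main__":
    hostname = socket.gethostname()
    with monitor_context("demo_job"):
        try:
            main()
        except Exception:
            print(f"Raise exception from {hostname}")
            send_alert(traceback.format_exc())
```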