diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 1e190d1..38247c1 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -7,22 +7,29 @@ MLP_RATIO = 8 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
 
+MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 # Ckpt folder format:
 # fs: 'local:/mnt/nfs/XXX'
-# oss: 'boto3:s3://model_weights/XXX'
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
 LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
+CHECKPOINT_EVERY = 50
 ckpt = dict(
-    # Path to save training ckpt.
-    save_ckpt_folder=SAVE_CKPT_FOLDER,
-    # Path to continue training ckpt (load model weights and scheduler/context states).
-    # load_ckpt_folder=LOAD_CKPT_FOLDER,
-    # Path to initialize with given model weights.
-    # load_model_only_folder=MODEL_ONLY_FOLDER,
-    checkpoint_every=50,
-    # Wheter to load optimizer states when continuing training.
-    load_optimizer=True,
+    enable_save_ckpt=False,  # enable ckpt save.
+    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
+    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
+    load_optimizer=True,  # Whether to load optimizer states when continuing training.
+    checkpoint_every=CHECKPOINT_EVERY,
+    async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
+    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
+    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage.
+    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
 
 TRAIN_FOLDER = "/path/to/dataset"
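For reference, a minimal sketch (not part of the patch) of the checkpoint cadence these new fields imply. The constants mirror the config above; the loop and prints are purely illustrative, and the slot alternation mirrors the `snapshot_counter` of the `CheckpointSaveManager` introduced later in this diff:

    import os

    SAVE_CKPT_FOLDER = "llm_ckpts"                 # "local:" prefix dropped for brevity
    CHECKPOINT_EVERY = 50
    SNAPSHOT_FOLDER = os.path.join(SAVE_CKPT_FOLDER, "snapshot")
    OSS_SNAPSHOT_FREQ = int(CHECKPOINT_EVERY / 2)  # 25: two snapshots per full checkpoint

    snapshot_counter = 0
    for step in range(1, 101):
        save_type = None
        if OSS_SNAPSHOT_FREQ > 1 and step % OSS_SNAPSHOT_FREQ == 0:
            save_type = "snapshot"
        if step % CHECKPOINT_EVERY == 0:
            save_type = "normal"  # a full checkpoint wins when both coincide
        if save_type == "snapshot":
            # only two snapshot slots exist; they are overwritten alternately
            snapshot_counter = (snapshot_counter + 1) % 2
            print(f"step {step}: snapshot -> {os.path.join(SNAPSHOT_FOLDER, str(snapshot_counter))}")
        elif save_type == "normal":
            print(f"step {step}: full checkpoint -> {SAVE_CKPT_FOLDER}")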
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 1f60adc..33b5d15 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -11,6 +11,7 @@ import torch
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
+from internlm.utils.storage_manager import init_storage_manager
 
 logger = get_logger(__file__)
 
@@ -122,20 +123,44 @@ def args_sanity_check():
     if "load_model_only_folder" not in gpc.config.ckpt:
         gpc.config.ckpt._add_item("load_model_only_folder", None)
 
+    if "async_upload" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("async_upload", False)
+    else:
+        if gpc.config.ckpt.async_upload:
+            assert "save_ckpt_folder" in gpc.config.ckpt
+            if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
+                if gpc.is_rank_for_log():
+                    logger.warning(
+                        "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
+                    )
+                gpc.config.ckpt.async_upload = False
+            else:
+                if "async_upload_tmp_folder" not in gpc.config.ckpt:
+                    gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
+
+    if "snapshot_ckpt_folder" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
+
+    if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
+        gpc.config.ckpt._add_item("oss_snapshot_freq", int(gpc.config.ckpt.checkpoint_every / 2))
+        assert gpc.config.ckpt.oss_snapshot_freq > 0
+
     assert not (
         gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
     ), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
 
-    gpc.config.ckpt._add_item(
-        "enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
-    )
+    if "enable_save_ckpt" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("enable_save_ckpt", False)
 
     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
-        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
+        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
         logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
         logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
 
+    # initialize storage manager
+    init_storage_manager(gpc.config.ckpt)
+
     # tensorboard writer config
     if "enable_tb" not in gpc.config:
         gpc.config._add_item("enable_tb", True)
@@ -202,7 +227,13 @@ def args_sanity_check():
     if "sequence_parallel" not in gpc.config.model:
         gpc.config.model._add_item("sequence_parallel", False)
     else:
-        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
+        assert not (
+            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+        ), "sequence parallel does not support use_flash_attn=False"
+
+    # feishu webhook address for alerting
+    if "alert_address" not in gpc.config:
+        gpc.config._add_item("alert_address", None)
 
 
 def launch(
diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index 0951ccd..d35b9c1 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -55,10 +55,10 @@ class Embedding1D(nn.Module):
 
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
-
+
         if gpc.config.model.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
-
+
         return output
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 2fa249c..50b4bf0 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )
 
 
@@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.model.sequence_parallel,
         )
 
 
@@ -170,7 +178,13 @@ class FeedForward(nn.Module):
             dtype=dtype,
         )
         self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.model.sequence_parallel,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = RowParallelLinearTorch(
             hidden_features,
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 3c58bd8..ee434f3 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -497,7 +497,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
-    sequence_parallel: bool = False,
+    sequence_parallel: bool = False,  # pylint: disable=W0613
     num_experts: int = 1,
 ):
     """
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index a84f058..8b80af2 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
 
 from internlm.core.context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
 
 
 def _split(input_, parallel_mode, dim=-1):
@@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
 
+
 def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     assert input.dtype == grad_output.dtype
     grad_weight = torch.matmul(grad_output.t(), input)
@@ -157,10 +161,11 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
 
+
 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
     Split the input and keep only the corresponding chuck to the rank.
-
+
     Args:
         input_: input matrix.
         parallel_mode: parallel mode.
@@ -180,7 +185,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _gather(grad_output, ctx.mode, ctx.dim), None, None
-
+
 
 def split_forward_gather_backward(input_, parallel_mode, dim):
     return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
@@ -189,14 +194,14 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
 def try_import_RMSNorm():
     """
     Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
-
+
     """
     try:
         from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
+
        return RMSNorm
-    except ModuleNotFoundError as e:
-        from internlm.utils.logger import get_logger
-        logger = get_logger(__file__)
+    except ModuleNotFoundError:
         logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm
+
         return RMSNorm
diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py
new file mode 100644
index 0000000..b100cde
--- /dev/null
+++ b/internlm/monitor/__init__.py
@@ -0,0 +1,4 @@
+from .monitor import initialize_monitor_manager, send_alert_message
+from .utils import set_env_var
+
+__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
diff --git a/internlm/monitor/alert.py b/internlm/monitor/alert.py
new file mode 100644
index 0000000..78b6040
--- /dev/null
+++ b/internlm/monitor/alert.py
@@ -0,0 +1,53 @@
+import json
+import time
+
+import requests
+
+
+def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
+    """
+    Use a Feishu bot to send messages to the given webhook.
+
+    Args:
+        webhook (str): The webhook to be used to send the message.
+        title (str): The message title.
+        message (str): The message body.
+
+    Returns:
+        The JSON response from the request, or None if the request failed.
+
+    Note:
+        Exceptions raised by the HTTP POST are caught and printed, not re-raised.
+    """
+
+    headers = {"Content-Type": "application/json;charset=utf-8"}
+    msg_body = {
+        "timestamp": int(time.time()),
+        "msg_type": "post",
+        "content": {
+            "post": {
+                "zh_cn": {
+                    "title": title,
+                    "content": [
+                        [
+                            {
+                                "tag": "text",
+                                "text": message,
+                            },
+                        ],
+                    ],
+                },
+            },
+        },
+    }
+
+    try:
+        res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
+        res = res.json()
+        print(f"Feishu webhook response: {res}")
+    except Exception as err:  # pylint: disable=W0703
+        print(f"HTTP Post error: {err}")
+        res = None
+
+    return res
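A quick usage sketch for the helper above; the webhook URL is a placeholder, not a real endpoint:

    from internlm.monitor.alert import send_feishu_msg_with_webhook

    res = send_feishu_msg_with_webhook(
        webhook="https://open.feishu.cn/open-apis/bot/v2/hook/xxxx",  # placeholder
        title="7B_sft demo job",
        message="Training started.",
    )
    # res is the parsed JSON response on success, or None if the POST failed.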
diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py
new file mode 100644
index 0000000..ca5cf55
--- /dev/null
+++ b/internlm/monitor/monitor.py
@@ -0,0 +1,226 @@
+import os
+import signal
+import socket
+import time
+from contextlib import contextmanager
+from threading import Thread
+
+from internlm.core.context import global_context as gpc
+from internlm.monitor.alert import send_feishu_msg_with_webhook
+from internlm.utils.common import SingletonMeta
+
+from .utils import get_job_key, set_env_var
+
+
+def send_alert_message(address: str = None, title: str = None, message: str = None):
+    """
+    Send an alert message to the given Feishu webhook address from the log rank.
+
+    Args:
+        address (str): The alert address to be used to send message, defaults to None.
+        title (str): The message title, defaults to None.
+        message (str): The message body, defaults to None.
+    """
+
+    if address is not None and gpc.is_rank_for_log():
+        send_feishu_msg_with_webhook(
+            webhook=address,
+            title=title if title else get_job_key(),
+            message=message,
+        )
+
+
+class MonitorTracker(Thread):
+    """
+    Track job status and alert to Feishu during job training.
+
+    Args:
+        alert_address (str): The Feishu webhook address for sending alerting messages.
+        check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
+        loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
+    """
+
+    def __init__(
+        self,
+        alert_address: str,
+        check_interval: float = 300,
+        loss_spike_limit: float = 1.5,
+    ):
+        super().__init__()
+        self.alert_address = alert_address
+        self.check_interval = check_interval
+        self.loss_spike_limit = loss_spike_limit
+        self.last_active_time = -1
+        self.last_loss_value = -1
+        self.stopped = False
+        self.start()
+
+    def run(self):
+        """
+        Start the monitor tracker.
+        """
+
+        while not self.stopped:
+            try:
+                self._check_stuck()
+                self._check_loss_spike()
+            except Exception:
+                pass
+            time.sleep(self.check_interval)
+
+    def _check_stuck(self):
+        """
+        Check training status for a potential stuck condition.
+        """
+
+        new_active_time = -1
+        if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
+            new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
+        if new_active_time != -1 and int(new_active_time) <= int(self.last_active_time):
+            self._send_alert("Training may be stuck, please check it.")
+        self.last_active_time = new_active_time
+
+    def _check_loss_spike(self):
+        """
+        Check for loss value spikes.
+        """
+
+        if gpc.is_rank_for_log():
+            new_loss_value = -1
+            new_step_id = -1
+            if os.getenv("LOSS") is not None:
+                new_loss_value = os.getenv("LOSS")
+            if os.getenv("STEP_ID") is not None:
+                new_step_id = os.getenv("STEP_ID")
+
+            if new_loss_value != -1 and (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit:
+                assert int(new_step_id) >= 0
+                self._send_alert(
+                    f"Checking periodically: a loss spike may have happened at step {new_step_id}, "
+                    f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
+                )
+
+            self.last_loss_value = new_loss_value
+
+    def _send_alert(self, message):
+        """
+        Send an alerting message to the Feishu webhook address.
+
+        Args:
+            message (str): The alerting message to be sent.
+        """
+
+        send_alert_message(
+            address=self.alert_address,
+            message=message,
+        )
+
+    def stop(self):
+        """
+        Stop the monitor tracker.
+        """
+
+        self.stopped = True
+
+
+class MonitorManager(metaclass=SingletonMeta):
+    """
+    Monitor Manager for managing the monitor thread and monitoring training status.
+    """
+
+    def __init__(self, loss_spike_limit: float = 1.5) -> None:
+        self.monitor_thread = None
+        self.loss_spike_limit = loss_spike_limit
+        self.last_step_loss = -1
+
+    def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
+        """Check the loss value; if a loss spike occurs, send an alert message to Feishu."""
+        set_env_var(key="LOSS", value=cur_step_loss)
+        set_env_var(key="STEP_ID", value=step_count)
+
+        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
+            send_alert_message(
+                address=alert_address,
+                message=(
+                    f"Checking step by step: a loss spike may have happened at step {step_count}, "
+                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
+                ),
+            )
+        self.last_step_loss = cur_step_loss
+
+    def monitor_exception(self, alert_address: str = None, excp_info: str = None):
+        """Catch and format exception information, then send an alert message to Feishu."""
+        filtered_trace = excp_info.split("\n")[-10:]
+        format_trace = ""
+        for line in filtered_trace:
+            format_trace += "\n" + line
+        send_alert_message(
+            address=alert_address,
+            message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
+        )
+
+    def handle_sigterm(self, alert_address: str = None):
+        """Catch the SIGTERM signal and send an alert message to Feishu."""
+
+        def sigterm_handler(sys_signal, frame):
+            print("receive frame: ", frame)
+            print("receive signal: ", sys_signal)
+            send_alert_message(
+                address=alert_address,
+                message=f"Process received signal {sys_signal} and exited.",
+            )
+
+        signal.signal(signal.SIGTERM, sigterm_handler)
+
+    def start_monitor(
+        self,
+        job_name: str,
+        alert_address: str,
+        monitor_interval_seconds: int = 300,
+        loss_spike_limit: float = 1.5,
+    ):
+        """
+        Initialize and start the monitor thread for checking training job status, loss spikes and so on.
+
+        Args:
+            job_name (str): The training job name.
+            alert_address (str): The Feishu webhook address for sending alert messages.
+            monitor_interval_seconds (int): The monitor interval in seconds, defaults to 300.
+            loss_spike_limit (float): The threshold of current loss to previous loss above which a loss
+                spike is assumed to have occurred, defaults to 1.5.
+        """
+
+        # initialize some variables for monitoring
+        set_env_var(key="JOB_NAME", value=job_name)
+
+        # start a monitor thread, periodically check the training status
+        self.monitor_thread = MonitorTracker(
+            alert_address=alert_address,
+            check_interval=monitor_interval_seconds,
+            loss_spike_limit=loss_spike_limit,
+        )
+
+    def stop_monitor(self):
+        """Stop the monitor and alert thread."""
+        if self.monitor_thread is not None:
+            self.monitor_thread.stop()
+
+
+monitor_manager = MonitorManager()
+
+
+@contextmanager
+def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
+    """Initialize the monitor manager as a context manager; a no-op when alert_address is None."""
+    if alert_address is not None:
+        try:
+            monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
+            monitor_manager.handle_sigterm(alert_address=alert_address)
+            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
+            yield
+        finally:
+            send_alert_message(
+                address=alert_address, message=f"Training in {socket.gethostname()} completed."
+            )
+            monitor_manager.stop_monitor()
+    else:
+        yield
diff --git a/internlm/monitor/utils.py b/internlm/monitor/utils.py
new file mode 100644
index 0000000..f64c7dc
--- /dev/null
+++ b/internlm/monitor/utils.py
@@ -0,0 +1,32 @@
+import os
+from datetime import datetime
+
+
+def now_time():
+    return datetime.now().strftime("%b%d_%H-%M-%S")
+
+
+def set_env_var(key, value):
+    os.environ[str(key)] = str(value)
+
+
+def get_job_id():
+    job_id = "none"
+    if os.getenv("SLURM_JOB_ID") is not None:
+        job_id = os.getenv("SLURM_JOB_ID")
+    elif os.getenv("K8S_WORKSPACE_ID") is not None:
+        job_id = os.getenv("K8S_WORKSPACE_ID")
+
+    return job_id
+
+
+def get_job_name():
+    job_name = f"unknown-{now_time()}"
+    if os.getenv("JOB_NAME") is not None:
+        job_name = os.getenv("JOB_NAME")
+
+    return job_name
+
+
+def get_job_key():
+    return f"{get_job_id()}_{get_job_name()}"
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 1cdb8f7..db77315 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -29,6 +29,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.monitor import send_alert_message
 
 from .utils import compute_norm
 
@@ -543,6 +544,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
+                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 584078f..d479284 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -34,18 +34,6 @@ def get_master_node():
     return result
 
 
-def get_process_rank():
-    proc_rank = -1
-    if os.getenv("SLURM_PROCID") is not None:
-        proc_rank = int(os.getenv("SLURM_PROCID"))
-    elif os.getenv("RANK") is not None:
-        # In k8s env, we use $RANK.
-        proc_rank = int(os.getenv("RANK"))
-
-    # assert proc_rank != -1, "get_process_rank cant't get right process rank!"
-    return proc_rank
-
-
 def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
     if torch.is_tensor(norm) and norm.device.type != "cuda":
         norm = norm.to(torch.cuda.current_device())
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 8424e16..d10f0c1 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -6,8 +6,8 @@ from tqdm import tqdm
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.metrics import AccPerplex
 from internlm.core.scheduler import SchedulerMetricHook
+from internlm.model.metrics import AccPerplex
 
 
 @contextmanager
@@ -90,15 +90,9 @@ def evaluate_on_val_dls(
                 total_val_bsz = len(batch[1])
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                if gpc.config.model.sequence_parallel:
-                    sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
-                    )
-                else:
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
+                tensor_shape = torch.Size(
+                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                )
 
                 with switch_evaluation_pipeline_scheduler(
                     trainer=trainer,
@@ -114,7 +108,6 @@ def evaluate_on_val_dls(
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                 grad_accum_batch_size = data_cfg.micro_bsz
-                # import pdb; pdb.set_trace()
                 with switch_evaluation_no_pipeline_scheduler(
                     trainer=trainer,
                     grad_accum_size=grad_accum_size,
@@ -170,4 +163,4 @@ def switch_sequence_parallel_mode():
             gpc.config.model.sequence_parallel = False
         yield
     finally:
-        gpc.config.model.sequence_parallel = prev_mode
\ No newline at end of file
+        gpc.config.model.sequence_parallel = prev_mode
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index ea0fad2..3fe29cc 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -4,6 +4,7 @@
 import copy
 import os
 import time
+from enum import Enum
 from typing import Dict
 
 import torch
@@ -15,10 +16,22 @@ from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.storage_manager import get_fns, llm_load, llm_save
+from internlm.utils.storage_manager import (
+    get_fns,
+    get_storage_manager,
+    llm_load,
+    llm_save,
+)
 
 logger = get_logger(__file__)
 
+quit_signal_handler = None
+
+
+class CheckpointType(Enum):
+    NORMAL_CHECKPOINT = 1
+    SNAPSHOT_CHECKPOINT = 2
+
 
 def get_model_topology(model):
     """
@@ -289,3 +302,77 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
 
     if gpc.is_rank_for_log():
         logger.info(f"reload load_scheduler:{lr_scheduler}")
+
+
+class CheckpointSaveManager:
+    """Decides when to save checkpoints and dispatches normal or snapshot saves."""
+
+    def __init__(
+        self,
+        ckpt_config,
+        model,
+        optimizer,
+        lr_scheduler,
+        model_config,
+    ) -> None:
+        """
+        CheckpointSaveManager is used to decide when to store a ckpt. In asynchronous
+        upload mode, you must call wait_async_upload_finish at the end of the program
+        to wait for the asynchronous ckpt upload to complete.
+
+        Args:
+            ckpt_config (dict): model checkpoint config.
+            model (nn.Module): model obj.
+            optimizer (object): optimizer obj.
+            lr_scheduler (object): lr_scheduler obj.
+            model_config (dict): model config.
+        """
+        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
+        self.checkpoint_every = ckpt_config.checkpoint_every
+        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
+        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
+        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
+        self.storage_manager = get_storage_manager()
+        self.snapshot_counter = 0
+
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.model_config = model_config
+
+    def try_save_checkpoint(self, train_state):
+        if not self.enable_save_ckpt:
+            return
+
+        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
+        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
+        if train_state.step_count % self.checkpoint_every == 0:
+            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
+        if save_ckpts is False:
+            if quit_signal_handler is not None:
+                save_ckpts, save_type = quit_signal_handler(train_state)
+
+        if save_ckpts:
+            # Wait for the previous round of asynchronous upload storage to complete.
+            self.storage_manager.wait()
+            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
+                # Snapshot number, with only two snapshots written alternately.
+                self.snapshot_counter = (self.snapshot_counter + 1) % 2
+                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
+            else:
+                save_ckpt_folder = self.save_ckpt_folder
+
+            save_checkpoint(
+                folder=save_ckpt_folder,
+                model=self.model,
+                optimizer=self.optimizer,
+                scheduler=self.lr_scheduler,
+                train_state=train_state,
+                model_config=self.model_config,
+            )
+
+    def wait_async_upload_finish(self):
+        """Wait for all checkpoint uploads to be completed."""
+        self.storage_manager.wait()
+        torch.distributed.barrier()
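To make the intended call pattern concrete, here is a minimal sketch of how `CheckpointSaveManager` is driven from a training loop. It mirrors the `train.py` changes at the end of this diff; the loop body is elided and the surrounding objects (`model`, `optimizer`, `train_state`, `total_steps`, ...) are assumed to exist:

    ckpt_save_manager = CheckpointSaveManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model_config=gpc.config.model,
    )

    for batch_count in range(total_steps):
        ...  # forward/backward/step; train_state.step_count advances on success
        ckpt_save_manager.try_save_checkpoint(train_state)

    # required in async upload mode: block until pending boto3 uploads finish
    ckpt_save_manager.wait_async_upload_finish()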
diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py
index 8bd7c88..481bd28 100644
--- a/internlm/utils/storage_manager.py
+++ b/internlm/utils/storage_manager.py
@@ -1,18 +1,26 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import asyncio
+import concurrent.futures
 import hashlib
 import io
 import os
+import pickle
 import re
 import socket
-from enum import Enum
-from typing import Any, Dict, List, Union
+import stat
+from asyncio import InvalidStateError
+from asyncio.tasks import ALL_COMPLETED
+from datetime import datetime
+from typing import Any, Awaitable, Callable, Dict, List, Union
 
 import boto3
 import botocore
 import torch
+import torch.distributed as dist
 
+from internlm.core.context import global_context as gpc
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
 
@@ -41,10 +49,6 @@ def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
     storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
 
 
-class CheckpointType(Enum):
-    NORMAL_CHECKPOINT = 1
-
-
 class StorageClient:
     """
     StorageClient as a client for s3 storage access.
@@ -54,7 +58,7 @@ class StorageClient:
         self.handler = handler
 
     @staticmethod
-    def load(client, load_path: str, map_location):
+    def load(client, load_path: str, *args, **kwargs):
         raise NotImplementedError
 
     @staticmethod
@@ -71,25 +75,51 @@ class StorageClient:
 
 
 class Boto3MetaInfo:
-    def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
-        self.client = client
+    """Boto3 meta info for save/load etc."""
+
+    def __init__(
+        self,
+        is_async,
+        handler: StorageClient,
+        bucket_name: str,
+        endpoint: str,
+        file_path: str,
+        async_upload_fn: Callable,
+        local_nvme_path=None,
+    ) -> None:
+        self.is_async = is_async
+        self.client = handler
         self.bucket_name = bucket_name
         self.endpoint = endpoint
         self.file_path = file_path
+        self.async_upload_fn = async_upload_fn
+        self.local_nvme_path = local_nvme_path
+
+    def __str__(self) -> str:
+        return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
+local_nvme_path: {self.local_nvme_path}"
 
 
 class LocalMetaInfo:
-    def __init__(self, client: StorageClient, dest_path: str) -> None:
-        self.client = client
+    """Local meta info for save/load etc."""
+
+    def __init__(self, handler: StorageClient, dest_path: str) -> None:
+        self.is_async = False
+        self.client = handler
         self.dest_path = dest_path
+        self.async_upload_fn = None
 
 
 def unpack_meta(meta):
     args = []
+    is_async = meta.is_async
     for k, v in meta.__dict__.items():
-        if k == "endpoint":
+        if k in ("endpoint", "async_upload_fn", "is_async"):
+            continue
+        if not is_async and k in ("local_nvme_path",):
             continue
         args.append(v)
+
     return args
@@ -101,21 +131,6 @@ def compute_file_md5_by_chunk(file_name: str):
     return hash_md5.hexdigest()
 
 
-def get_boto3_meta(fp: str) -> Boto3MetaInfo:
-    assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
-    parts = fp.lstrip("s3://").split(os.path.sep)
-    match = boto3_url_re.match(parts[0])
-    assert match is not None, f"url '{fp}' is not a valid boto3 url"
-    bucket_name, endpoint = match.group(1), match.group(2)
-    endpoint = "http://" + endpoint + ":80"
-    return Boto3MetaInfo(None, bucket_name, endpoint, os.path.sep.join(parts[1:]))
-
-
-def get_local_meta(fp: str) -> LocalMetaInfo:
-    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
-    return LocalMetaInfo(None, fp)
-
-
 class Boto3Client(StorageClient):
     """
     Boto3Client
@@ -169,7 +184,9 @@ class Boto3Client(StorageClient):
         )
 
     @staticmethod
-    def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
+    def sync_upload_fileobj(
+        handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
+    ):  # pylint: disable=W0613
         assert saved_obj is not None, "saved_obj is None!"
         try:
             with io.BytesIO() as f:
@@ -182,7 +199,14 @@ class Boto3Client(StorageClient):
             ) from exc
 
     @staticmethod
-    def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
+    def load(
+        handler,
+        bucket_name: str,
+        fp: str,
+        local_nvme_path: str,  # pylint: disable=W0613
+        *args,
+        **kwargs,
+    ) -> Dict:
         """
         Args:
             fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@@ -191,7 +215,7 @@ class Boto3Client(StorageClient):
             with io.BytesIO() as f:
                 handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                 f.seek(0)
-                states = torch.load(f, *args, map_location=map_location, **kwargs)
+                states = torch.load(f, *args, **kwargs)
         except handler.botocore.exceptions.EndpointConnectionError as exc:
             raise RuntimeError(
                 f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@@ -199,15 +223,11 @@ class Boto3Client(StorageClient):
         return states
 
     @staticmethod
-    def assert_fp_exists(
-        handler,
-        bucket_name: str,
-        fp: str,
-    ):
+    def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
         assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
 
     @staticmethod
-    def get_fns(handler, bucket_name: str, fp: str):
+    def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs):  # pylint: disable=W0613
         """
         Ref: https://stackoverflow.com/questions/54314563/
         how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
@@ -222,6 +242,22 @@ class Boto3Client(StorageClient):
                 folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
         return folder_name_list
 
+    @staticmethod
+    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
+        try:
+            with open(local_nvme_path, "rb") as f:
+                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
+        except handler.botocore.exceptions.EndpointConnectionError as exc:
+            raise RuntimeError(
+                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
+            ) from exc
+        except Exception as e:
+            raise e
+
+    @staticmethod
+    def delete_obj(handler, fp: str):
+        raise NotImplementedError("boto3 does not support delete_obj")
+
 
 class LocalClient(StorageClient):
     """
@@ -241,11 +277,11 @@ class LocalClient(StorageClient):
             torch.save(saved_obj, fp, *args, **kwargs)
 
     @staticmethod
-    def load(handler, fp: str, *args, map_location="cpu", **kwargs):
+    def load(handler, fp: str, *args, **kwargs):  # pylint: disable=W0613
         assert isinstance(handler, LocalClient)
         assert os.path.exists(fp), f"{fp} is not found!"
         with open(fp, "rb") as f:
-            states = torch.load(f, map_location=map_location, *args, **kwargs)
+            states = torch.load(f, *args, **kwargs)
         return states
 
     @staticmethod
@@ -267,9 +303,77 @@ class LocalClient(StorageClient):
         os.remove(fp)
 
 
+def get_tmp_file_name(tmp_local_folder: str, fp: str):
+    """
+    Note that all temporary files are stored in the same folder, so the file
+    name passed upstream must be unique.
+    """
+    base_path = os.path.join(tmp_local_folder, fp.split("/")[-1])
+    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+    pid = os.getpid()
+    return "-".join([base_path, current_time, str(pid)]) + ".tmpfile"
+
+
+def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
+    assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
+    parts = fp.lstrip("s3://").split(os.path.sep)
+    match = boto3_url_re.match(parts[0])
+    assert match is not None, f"url '{fp}' is not a valid boto3 url"
+    bucket_name, endpoint = match.group(1), match.group(2)
+    endpoint = "http://" + endpoint + ":80"
+    tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
+    return Boto3MetaInfo(
+        is_async=is_async,
+        handler=None,
+        bucket_name=bucket_name,
+        endpoint=endpoint,
+        file_path=os.path.sep.join(parts[1:]),
+        async_upload_fn=Boto3Client.async_upload_fileobj,
+        local_nvme_path=tmp_step_file,
+    )
+
+
+def get_local_meta(fp: str) -> LocalMetaInfo:
+    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
+    return LocalMetaInfo(None, fp)
+
+
+def get_mount_point_free_size(path: str):
+    """
+    Returns the free space of the temporary storage mount point, in GB.
+
+    Args:
+        path (str): temporary storage folder path.
+
+    Raises:
+        FileNotFoundError: If the temporary storage folder does not exist.
+    """
+    if not os.path.exists(path):
+        raise FileNotFoundError(f"Path '{path}' does not exist.")
+    st = os.statvfs(path)
+    # f_bavail: free blocks for unprivileged users; f_bsize: filesystem block size.
+    # The returned unit is GB.
+    return st.f_bavail * st.f_bsize / (1024**3)
+
+
+def check_tmp_folder_accessibility(tmp_local_folder: str):
+    """
+    Check access permissions for temporary storage.
+    """
+    ret = True
+    if os.path.exists(tmp_local_folder):
+        ret &= os.access(tmp_local_folder, os.W_OK)
+        ret &= os.access(tmp_local_folder, os.R_OK)
+        if ret is False:
+            error_str = f"{socket.gethostname()} does not have read and write permissions on {tmp_local_folder}"
+            raise RuntimeError(error_str)
+
+
 class StorageManager(metaclass=SingletonMeta):
     """
     Storage Manager for saving or loading checkpoint.
+    TODO: add a thread to poll the asynchronous storage state.
     """
 
     BACKEND_TYPE = {"boto3", "local"}
@@ -279,8 +383,39 @@ class StorageManager(metaclass=SingletonMeta):
     }
     CLI_DICT = {}
 
-    def __init__(self) -> None:
-        pass
+    def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
+        self._exception_list = []
+        self._to_be_del_files = []
+        self._async_stack = []
+        self.upload_count = 0
+        self.tmp_local_folder = tmp_local_folder
+        self.async_mode = async_mode
+        self.has_warning = False
+        # Always defined, so that wait()/async_executor() degrade gracefully when saving is disabled.
+        self._async_loop = None
+        self._thread_pool = None
+
+        if enable_save and self.async_mode:
+            self._async_loop = asyncio.new_event_loop()
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
+
+            check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
+
+            # Try to create the tmp folder.
+            try:
+                os.makedirs(self.tmp_local_folder, exist_ok=True)
+                os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
+            except FileExistsError:
+                pass
+
+            # In case it is a directory created by other users, check the permissions again.
+            check_tmp_folder_accessibility(self.tmp_local_folder)
+
+            # Try to clean leftover temporary files in the tmp folder.
+            self.try_delete_tmpfile(self.tmp_local_folder)
+
+            # Available storage space check.
+            free_size = get_mount_point_free_size(self.tmp_local_folder)
+            if free_size < 0.1:
+                logger.error(f'tmp_local_folder only has "{free_size}" GB of free space, less than 0.1 GB!')
+                raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
 
     def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
         """
@@ -301,7 +436,7 @@ class StorageManager(metaclass=SingletonMeta):
             meta_info = get_local_meta(path)
             backend_key = backend
         elif backend == "boto3":
-            meta_info = get_boto3_meta(path)
+            meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
             backend_key = backend + ":" + meta_info.endpoint
             init_args = (meta_info.endpoint,)
             if (
@@ -310,10 +445,12 @@ class StorageManager(metaclass=SingletonMeta):
                 or "HTTP_PROXY" in os.environ
                 or "HTTPS_PROXY" in os.environ
             ):
-                raise RuntimeWarning(
-                    "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
-the proxy may make boto3 unavailable or affect performance."
-                )
+                if not self.has_warning:
+                    logger.warning(
+                        "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting "
+                        "the proxy may make boto3 unavailable or affect performance."
+                    )
+                    self.has_warning = True
 
         assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
 
@@ -333,19 +470,137 @@ the proxy may make boto3 unavailable or affect performance."
         meta = self._get_client(path=folder)
         return meta.client.get_fns(*unpack_meta(meta))
 
-    def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
+    def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
         meta = self._get_client(path=save_path)
 
-        meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
-
-    def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
+        if async_upload is None:
+            async_upload = self.async_mode
+        if async_upload:
+            assert (
+                self.tmp_local_folder
+            ), "tmp_local_folder is not set for StorageManager, so async save cannot be performed."
+            tmp_step_file = meta.local_nvme_path
+            self._to_be_del_files.append(tmp_step_file)
+            with open(tmp_step_file, "wb") as f:
+                torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+            self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
+            os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
+        else:
+            meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
+            self.upload_count += 1
 
+    def load(self, load_path: str, *args, **kwargs) -> Any:
+        self.wait()
         meta = self._get_client(path=load_path)
-        return meta.client.load(*unpack_meta(meta), map_location=map_location, *args, **kwargs)
+        return meta.client.load(*unpack_meta(meta), *args, **kwargs)
 
     def delete_obj(self, fp: str):
         meta = self._get_client(path=fp)
         meta.client.delete_obj(*unpack_meta(meta))
 
+    def _del_tmp_folder(self):
+        for fp in self._to_be_del_files:
+            try:
+                os.remove(fp)
+            except FileNotFoundError:
+                pass
+            except OSError as e:
+                logger.error(f'delete file: {fp}, failed for reason: "{e}"')
+            else:
+                pass
 
-storage_manager = StorageManager()
+    def try_delete_tmpfile(self, tmp_dir: str):
+        """Delete temporary files in tmp_dir."""
+
+        for filename in os.listdir(tmp_dir):
+            if filename.endswith(".tmpfile"):
+                file_path = os.path.join(tmp_dir, filename)
+                try:
+                    os.remove(file_path)
+                    logger.info(f"Delete tmpfile: {file_path}")
+                except OSError:
+                    # Ignore deletion errors
+                    pass
+
+    async def _sync_tasks(self) -> Awaitable[None]:
+
+        if not self._async_stack:
+            return
+
+        await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
+
+        for file_id, task in enumerate(self._async_stack):
+            try:
+                exc = task.exception()
+            except InvalidStateError:
+                continue
+            if exc is not None:
+                self._exception_list.append((file_id, exc))
+
+                logger.error(f"File: {self._to_be_del_files[file_id]}, upload failed with {exc}")
+
+        self._async_stack.clear()
+
+    def async_executor(self, fn: Callable, *args, **kwargs) -> None:
+        """
+        Overview:
+            Execute a task in the background, then append the future instance to _async_stack.
+        Arguments:
+            - fn (:obj:`Callable`): The synchronous function to execute.
+        """
+        if not self._async_loop:
+            raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
+        t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
+        self._async_stack.append(t)
+
+    def wait(self) -> None:
+        """Wait for async operations to complete."""
+
+        if not self.async_mode:
+            return
+
+        if self._async_loop:
+            self._async_loop.run_until_complete(self._sync_tasks())
+
+        if self._exception_list:
+            for file_id, error_msg in self._exception_list:
+                logger.error(
+                    f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
+                    f"failed on step {self.upload_count}: {error_msg}"
+                )
+
+                # TODO: Re-upload in sync mode
+                raise RuntimeError(
+                    f"Failed to upload {self._to_be_del_files[file_id]} on step {self.upload_count}: {error_msg}"
+                )
+
+        self._del_tmp_folder()
+        self._exception_list.clear()
+        self._to_be_del_files.clear()
+
+        if gpc.is_rank_for_log():
+            logger.info("all async uploads succeeded!")
+            self.upload_count += 1
+
+
+storage_manager: StorageManager = None
+
+
+def init_storage_manager(ckpt_config):
+    global storage_manager
+    storage_manager = StorageManager(
+        ckpt_config.enable_save_ckpt,
+        tmp_local_folder=ckpt_config.async_upload_tmp_folder,
+        async_mode=ckpt_config.async_upload,
+    )
+
+
+def get_storage_manager():
+    assert storage_manager is not None, "storage_manager has not been init!"
+    return storage_manager
+
+
+def wait_async_upload_finish():
+    dist.barrier()
+    storage_manager.wait()
diff --git a/train.py b/train.py
index 675cc77..72f2820 100644
--- a/train.py
+++ b/train.py
@@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
+from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
@@ -37,7 +39,6 @@ from internlm.utils.common import (
     BatchSkipper,
     get_master_node,
     get_megatron_flops,
-    get_process_rank,
     launch_time,
     parse_args,
 )
@@ -45,12 +46,12 @@ from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
+    CheckpointSaveManager,
     load_context,
     load_model_checkpoint,
     load_optimizer_checkpoint,
     load_sampler,
     load_scheduler,
-    save_checkpoint,
 )
 from internlm.utils.parallel import (
     get_parallel_log_file_name,
@@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
 
 
 def initialize_llm_logger(start_time: str):
+    """
+    Initialize customized uniscale logger.
+
+    Args:
+        start_time (str): The launch time of the current training job.
+
+    Returns: The instance of uniscale logger.
+    """
+
     uniscale_logger = initialize_uniscale_logger(
         job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
     )
@@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
 
 
 def get_validation_data_loader(num_worker: int = 0):
+    """Generate and return the validation data loader."""
+
     data_cfg = gpc.config.data
 
     if not data_cfg.valid_folder:
@@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
     Print some training metrics of current batch.
     """
 
+    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
+
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
     if is_no_pp_or_last_stage():
@@ -405,12 +419,11 @@ def record_current_batch_training_metrics(
         else:
             logger.info(line)
 
+    # if a loss spike occurs, send alert info to Feishu
+    mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
+
 
 def main(args):
-    # initialize distributed environment
-    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
-    assert hasattr(gpc, "config") and gpc.config is not None
-
     # init setting
     skip_batches = gpc.config.data.skip_batches
     total_steps = gpc.config.data.total_steps
@@ -419,11 +432,6 @@ def main(args):
     label_smoothing = gpc.config.loss.label_smoothing
     lr = gpc.config.adam.lr
 
-    # ckpt setting
-    save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder
-    enable_save_ckpt = gpc.config.ckpt.enable_ckpt
-    checkpoint_every = gpc.config.ckpt.checkpoint_every
-
     load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
     load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
 
@@ -477,8 +485,8 @@ def main(args):
         model_load_path = load_model_only_folder
     else:
         logger.info(
-            f"===========New Run {current_time} on host:{socket.gethostname()},"
-            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
+            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
+            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
            f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
         )
 
@@ -514,6 +522,14 @@ def main(args):
         if load_optimizer:
             load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
 
+    ckpt_save_manager = CheckpointSaveManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model_config=gpc.config.model,
+    )
+
     # initialize metric for calculating accuracy and perplexity
     metric = AccPerplex(
         device=torch.cuda.current_device(),
@@ -594,6 +610,9 @@ def main(args):
                 train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
                 if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                     logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                    send_alert_message(
+                        address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
+                    )
 
             # calculate and record the training metrics, eg. loss, accuracy and so on.
             record_current_batch_training_metrics(
@@ -629,26 +648,27 @@ def main(args):
             )
 
         # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
-        # save batch sampler that tracks the true consumed samples
-        if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
-            save_checkpoint(
-                folder=save_ckpt_folder,
-                model=model,
-                optimizer=optimizer,
-                scheduler=lr_scheduler,
-                train_state=train_state,
-                model_config=gpc.config.model,
-            )
+        # save batch sampler that tracks the true consumed samples
+        ckpt_save_manager.try_save_checkpoint(train_state)
 
-        # wait for all checkpoint uploads to be completed
-        dist.barrier()
+    ckpt_save_manager.wait_async_upload_finish()
 
 
 if __name__ == "__main__":
     args = parse_args()
+    hostname = socket.gethostname()
 
-    try:
-        main(args)
-    except Exception:
-        print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
-        traceback.print_exc()
+    # initialize distributed environment
+    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+
+    # initialize monitor manager context
+    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+        try:
+            main(args)
+        except Exception:
+            logger.error(
+                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
+                exc_info=traceback.format_exc(),
+            )
+            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
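For context: the monitor thread and the training loop communicate only through process environment variables. A minimal sketch of that handshake, with illustrative values (environment values are always strings, hence the casts inside `MonitorTracker`):

    import os
    import time

    from internlm.monitor import set_env_var

    # stamped by the training loop on every step (see
    # record_current_batch_training_metrics and monitor_loss_spike)
    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
    set_env_var(key="LOSS", value=2.31)
    set_env_var(key="STEP_ID", value=100)

    # read back by MonitorTracker on its polling interval
    assert os.getenv("LOSS") == "2.31"       # stored as a string
    assert int(os.getenv("STEP_ID")) == 100  # hence the int()/float() casts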