Merge branch 'develop' into feature_add_moe

pull/375/head
Wenwen Qu 2023-08-08 16:51:10 +08:00 committed by GitHub
commit 9ad7942568
16 changed files with 851 additions and 134 deletions

View File

@ -7,22 +7,29 @@ MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
# oss: 'boto3:s3://model_weights/XXX'
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
# Path to save training ckpt.
save_ckpt_folder=SAVE_CKPT_FOLDER,
# Path to continue training ckpt (load model weights and scheduler/context states).
# load_ckpt_folder=LOAD_CKPT_FOLDER,
# Path to initialize with given model weights.
# load_model_only_folder=MODEL_ONLY_FOLDER,
checkpoint_every=50,
# Whether to load optimizer states when continuing training.
load_optimizer=True,
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
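For reference, a short sketch (illustration only, not part of the config file) of how the snapshot settings above resolve with these values; the alternating snapshot slots come from the CheckpointSaveManager added further down in this diff:

# Worked example with the values defined above.
CHECKPOINT_EVERY = 50
SAVE_CKPT_FOLDER = "local:llm_ckpts"
oss_snapshot_freq = int(CHECKPOINT_EVERY / 2)  # 25 -> a snapshot ckpt every 25 steps
snapshot_ckpt_folder = "/".join([SAVE_CKPT_FOLDER, "snapshot"])  # "local:llm_ckpts/snapshot"
# Snapshots are then written alternately to .../snapshot/0 and .../snapshot/1.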
TRAIN_FOLDER = "/path/to/dataset"

View File

@ -11,6 +11,7 @@ import torch
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager
logger = get_logger(__file__)
@ -122,20 +123,44 @@ def args_sanity_check():
if "load_model_only_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_model_only_folder", None)
if "async_upload" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload", False)
else:
if gpc.config.ckpt.async_upload:
assert "save_ckpt_folder" in gpc.config.ckpt
if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning(
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
)
gpc.config.ckpt.async_upload = False
else:
if "async_upload_tmp_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "snapshot_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
assert gpc.config.ckpt.oss_snapshot_freq > 0
assert not (
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
gpc.config.ckpt._add_item(
"enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
)
if "enable_save_ckpt" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("enable_save_ckpt", False)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
# initialization storage manager
init_storage_manager(gpc.config.ckpt)
# tensorboard writer config
if "enable_tb" not in gpc.config:
gpc.config._add_item("enable_tb", True)
@ -202,7 +227,13 @@ def args_sanity_check():
if "sequence_parallel" not in gpc.config.model:
gpc.config.model._add_item("sequence_parallel", False)
else:
assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
assert not (
gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
def launch(

View File

@ -58,7 +58,11 @@ class ScaleColumnParallelLinear(nn.Linear):
else:
weight = self.weight
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
)
@ -103,7 +107,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
else:
weight = self.weight
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
)
@ -170,7 +178,13 @@ class FeedForward(nn.Module):
dtype=dtype,
)
self.w2 = ColumnParallelLinearTorch(
in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.model.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = RowParallelLinearTorch(
hidden_features,

View File

@ -497,7 +497,7 @@ def build_model_with_cfg(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
sequence_parallel: bool = False,
sequence_parallel: bool = False, # pylint: disable=W0613
num_experts: int = 1,
):
"""

View File

@ -16,6 +16,9 @@ from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
def _split(input_, parallel_mode, dim=-1):
@ -84,6 +87,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
assert input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), input)
@ -157,6 +161,7 @@ def fused_dense_func_torch(
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the chunk corresponding to the rank.
@ -193,10 +198,10 @@ def try_import_RMSNorm():
"""
try:
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
return RMSNorm
except ModuleNotFoundError as e:
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
except ModuleNotFoundError:
logger.warning("The torch implementation for MixedFusedRMSNorm is slower than apex. Please note this!")
from internlm.model.norm import RMSNormTorch as RMSNorm
return RMSNorm

View File

@ -0,0 +1,4 @@
from .monitor import initialize_monitor_manager, send_alert_message
from .utils import set_env_var
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]

internlm/monitor/alert.py (new file, 53 lines)
View File

@ -0,0 +1,53 @@
import json
import time
import requests
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
"""
Use Feishu robot to send messages with the given webhook.
Args:
webhook (str): The webhook to be used to send message.
title (str): The message title.
message (str): The message body.
Returns:
The response from the request, or None if an exception was caught.
Raises:
Exception: An exception raised by the HTTP post request.
"""
headers = {"Content-Type": "application/json;charset=utf-8"}
msg_body = {
"timestamp": int(time.time()),
"msg_type": "post",
"content": {
"post": {
"zh_cn": {
"title": title,
"content": [
[
{
"tag": "text",
"text": message,
},
],
],
},
},
},
}
try:
res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
res = res.json()
print(f"Feishu webhook response: {res}")
except Exception as err: # pylint: disable=W0703
print(f"HTTP Post error: {err}")
res = None
return res
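A minimal usage sketch of the helper above; the webhook URL, title and message are placeholders:

if __name__ == "__main__":
    # Hypothetical values; substitute a real Feishu bot webhook URL.
    send_feishu_msg_with_webhook(
        webhook="https://open.feishu.cn/open-apis/bot/v2/hook/<token>",
        title="internlm-demo-job",
        message="Training started.",
    )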

internlm/monitor/monitor.py (new file, 226 lines)
View File

@ -0,0 +1,226 @@
import os
import signal
import socket
import time
from contextlib import contextmanager
from threading import Thread
from internlm.core.context import global_context as gpc
from internlm.monitor.alert import send_feishu_msg_with_webhook
from internlm.utils.common import SingletonMeta
from .utils import get_job_key, set_env_var
def send_alert_message(address: str = None, title: str = None, message: str = None):
"""
Send alert messages to the given Feishu webhook address from the log rank.
Args:
address (str): The alert address to be used to send message, defaults to None.
title (str): The message title, defaults to None.
message (str): The message body, defaults to None.
"""
if address is not None and gpc.is_rank_for_log():
send_feishu_msg_with_webhook(
webhook=address,
title=title if title else get_job_key(),
message=message,
)
class MonitorTracker(Thread):
"""
Track job status and send alerts to Feishu during training.
Args:
alert_address (str): The Feishu webhook address for sending alerting messages.
check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
"""
def __init__(
self,
alert_address: str,
check_interval: float = 300,
loss_spike_limit: float = 1.5,
):
super().__init__()
self.alert_address = alert_address
self.check_interval = check_interval
self.loss_spike_limit = loss_spike_limit
self.last_active_time = -1
self.last_loss_value = -1
self.stopped = False
self.start()
def run(self):
"""
start the monitor tracker.
"""
while not self.stopped:
try:
self._check_stuck()
self._check_loss_spike()
except Exception:
continue
time.sleep(self.check_interval)
def _check_stuck(self):
"""
Check training status for potential stuck condition.
"""
new_active_time = -1
if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
self._send_alert("Training may be in stuck status, please check it.")
self.last_active_time = new_active_time
def _check_loss_spike(self):
"""
Check for loss value spikes.
"""
if gpc.is_rank_for_log():
new_loss_value = -1
new_step_id = -1
if os.getenv("LOSS") is not None:
new_loss_value = os.getenv("LOSS")
if os.getenv("STEP_ID") is not None:
new_step_id = os.getenv("STEP_ID")
if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
assert int(new_step_id) >= 0
self._send_alert(
f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
)
self.last_loss_value = new_loss_value
def _send_alert(self, message):
"""
Send alerting message to the Feishu webhook address.
Args:
message (str): The alerting message to be sent.
"""
send_alert_message(
address=self.alert_address,
message=message,
)
def stop(self):
"""
Stop the monitor tracker.
"""
self.stopped = True
class MonitorManager(metaclass=SingletonMeta):
"""
Monitor Manager for managing monitor thread and monitoring training status.
"""
def __init__(self, loss_spike_limit: float = 1.5) -> None:
self.monitor_thread = None
self.loss_spike_limit = loss_spike_limit
self.last_step_loss = -1
def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
"""Check loss value, if loss spike occurs, send alert message to Feishu."""
set_env_var(key="LOSS", value=cur_step_loss)
set_env_var(key="STEP_ID", value=step_count)
if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
send_alert_message(
address=alert_address,
message=(
f"Checking step by step: Loss spike may be happened in step {step_count}, "
f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
),
)
self.last_step_loss = cur_step_loss
def monitor_exception(self, alert_address: str = None, excp_info: str = None):
"""Catch and format exception information, send alert message to Feishu."""
filtered_trace = excp_info.split("\n")[-10:]
format_trace = ""
for line in filtered_trace:
format_trace += "\n" + line
send_alert_message(
address=alert_address,
message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
)
def handle_sigterm(self, alert_address: str = None):
"""Catch SIGTERM signal, and send alert message to Feishu."""
def sigterm_handler(sys_signal, frame):
print("receive frame: ", frame)
print("receive signal: ", sys_signal)
send_alert_message(
address=alert_address,
message=f"Process received signal {signal} and exited.",
)
signal.signal(signal.SIGTERM, sigterm_handler)
def start_monitor(
self,
job_name: str,
alert_address: str,
monitor_interval_seconds: int = 300,
loss_spike_limit: float = 1.5,
):
"""
Initialize and start monitor thread for checking training job status, loss spike and so on.
Args:
job_name (str): The training job name.
alert_address (str): The Feishu webhook address for sending alert messages.
monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
loss_spike_limit (float): The threshold ratio of the current loss to the previous loss above which a loss
spike is reported, defaults to 1.5.
"""
# initialize some variables for monitoring
set_env_var(key="JOB_NAME", value=job_name)
# start a monitor thread, periodically check the training status
self.monitor_thread = MonitorTracker(
alert_address=alert_address,
check_interval=monitor_interval_seconds,
loss_spike_limit=loss_spike_limit,
)
def stop_monitor(self):
"""Stop the monitor and alert thread."""
if self.monitor_thread is not None:
self.monitor_thread.stop()
monitor_manager = MonitorManager()
@contextmanager
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
if alert_address is not None:
try:
monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
monitor_manager.handle_sigterm(alert_address=alert_address)
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield
finally:
send_alert_message(
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
monitor_manager.stop_monitor()
else:
yield
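A sketch of how these pieces are wired together, mirroring the train.py changes later in this diff; it assumes the distributed/gpc context is already initialized, and run_training and the webhook address are placeholders:

import traceback

from internlm.monitor import initialize_monitor_manager
from internlm.monitor.monitor import monitor_manager as mm

ALERT_ADDRESS = "https://open.feishu.cn/open-apis/bot/v2/hook/<token>"  # placeholder webhook

with initialize_monitor_manager(job_name="demo_job", alert_address=ALERT_ADDRESS):
    try:
        run_training()  # placeholder for the actual training entry point
    except Exception:
        mm.monitor_exception(alert_address=ALERT_ADDRESS, excp_info=traceback.format_exc())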

internlm/monitor/utils.py (new file, 32 lines)
View File

@ -0,0 +1,32 @@
import os
from datetime import datetime
def now_time():
return datetime.now().strftime("%b%d_%H-%M-%S")
def set_env_var(key, value):
os.environ[str(key)] = str(value)
def get_job_id():
job_id = "none"
if os.getenv("SLURM_JOB_ID") is not None:
job_id = os.getenv("SLURM_JOB_ID")
elif os.getenv("K8S_WORKSPACE_ID") is not None:
job_id = os.getenv("K8S_WORKSPACE_ID")
return job_id
def get_job_name():
job_name = f"unknown-{now_time()}"
if os.getenv("JOB_NAME") is not None:
job_name = os.getenv("JOB_NAME")
return job_name
def get_job_key():
return f"{get_job_id()}_{get_job_name()}"

View File

@ -29,6 +29,7 @@ from internlm.solver.optimizer.utils import (
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.monitor import send_alert_message
from .utils import compute_norm
@ -543,6 +544,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None

View File

@ -34,18 +34,6 @@ def get_master_node():
return result
def get_process_rank():
proc_rank = -1
if os.getenv("SLURM_PROCID") is not None:
proc_rank = int(os.getenv("SLURM_PROCID"))
elif os.getenv("RANK") is not None:
# In k8s env, we use $RANK.
proc_rank = int(os.getenv("RANK"))
# assert proc_rank != -1, "get_process_rank cant't get right process rank!"
return proc_rank
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if torch.is_tensor(norm) and norm.device.type != "cuda":
norm = norm.to(torch.cuda.current_device())

View File

@ -6,8 +6,8 @@ from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.metrics import AccPerplex
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex
@contextmanager
@ -90,12 +90,6 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.model.sequence_parallel:
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
)
else:
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
@ -114,7 +108,6 @@ def evaluate_on_val_dls(
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
# import pdb; pdb.set_trace()
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,

View File

@ -4,6 +4,7 @@
import copy
import os
import time
from enum import Enum
from typing import Dict
import torch
@ -15,10 +16,22 @@ from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
from internlm.utils.storage_manager import (
get_fns,
get_storage_manager,
llm_load,
llm_save,
)
logger = get_logger(__file__)
quit_signal_handler = None
class CheckpointType(Enum):
NORMAL_CHECKPOINT = 1
SNAPSHOT_CHECKPOINT = 2
def get_model_topology(model):
"""
@ -289,3 +302,77 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
if gpc.is_rank_for_log():
logger.info(f"reload load_scheduler:{lr_scheduler}")
class CheckpointSaveManager:
"""StorageManagerContext"""
def __init__(
self,
ckpt_config,
model,
optimizer,
lr_scheduler,
model_config,
) -> None:
"""
CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous
upload mode, you must call wait_async_upload_finish at the end of the program to wait
for the asynchronous ckpt upload to complete.
Args:
ckpt_config (dict): model checkpoint config.
model (nn.module): model obj
optimizer (object): optimizer obj.
lr_scheduler (object): lr_scheduler obj.
model_config (dict): model config.
"""
self.enable_save_ckpt = ckpt_config.enable_save_ckpt
self.checkpoint_every = ckpt_config.checkpoint_every
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
self.storage_manager = get_storage_manager()
self.snapshot_counter = 0
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.model_config = model_config
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0:
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
if save_ckpts is False:
if quit_signal_handler is not None:
save_ckpts, save_type = quit_signal_handler(train_state)
if save_ckpts:
# Wait for the previous round of asynchronous upload storage to complete.
self.storage_manager.wait()
if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
# Snapshot number, with only two snapshots written alternately.
self.snapshot_counter = (self.snapshot_counter + 1) % 2
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
else:
save_ckpt_folder = self.save_ckpt_folder
save_checkpoint(
folder=save_ckpt_folder,
model=self.model,
optimizer=self.optimizer,
scheduler=self.lr_scheduler,
train_state=train_state,
model_config=self.model_config,
)
def wait_async_upload_finish(self):
"""wait for all checkpoint uploads to be completed"""
self.storage_manager.wait()
torch.distributed.barrier()
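A usage sketch of the manager above, following the pattern train.py adopts later in this diff; model, optimizer, lr_scheduler, train_state, total_steps and gpc.config are assumed to come from the usual training setup:

ckpt_save_manager = CheckpointSaveManager(
    ckpt_config=gpc.config.ckpt,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    model_config=gpc.config.model,
)

for _ in range(total_steps):  # training loop (placeholder)
    ...  # forward/backward/step, update train_state
    ckpt_save_manager.try_save_checkpoint(train_state)

# Required in async upload mode: wait for pending ckpt uploads before the process exits.
ckpt_save_manager.wait_async_upload_finish()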

View File

@ -1,18 +1,26 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import asyncio
import concurrent.futures
import hashlib
import io
import os
import pickle
import re
import socket
from enum import Enum
from typing import Any, Dict, List, Union
import stat
from asyncio import InvalidStateError
from asyncio.tasks import ALL_COMPLETED
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Union
import boto3
import botocore
import torch
import torch.distributed as dist
from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
@ -41,10 +49,6 @@ def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
class CheckpointType(Enum):
NORMAL_CHECKPOINT = 1
class StorageClient:
"""
StorageClient as a client for s3 storage access.
@ -54,7 +58,7 @@ class StorageClient:
self.handler = handler
@staticmethod
def load(client, load_path: str, map_location):
def load(client, load_path: str, *args, **kwargs):
raise NotImplementedError
@staticmethod
@ -71,25 +75,51 @@ class StorageClient:
class Boto3MetaInfo:
def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
self.client = client
"""Boto3 meta info for save/load etc."""
def __init__(
self,
is_async,
handler: StorageClient,
bucket_name: str,
endpoint: str,
file_path: str,
async_upload_fn: callable,
local_nvme_path=None,
) -> None:
self.is_async = is_async
self.client = handler
self.bucket_name = bucket_name
self.endpoint = endpoint
self.file_path = file_path
self.async_upload_fn = async_upload_fn
self.local_nvme_path = local_nvme_path
def __str__(self) -> str:
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"
class LocalMetaInfo:
def __init__(self, client: StorageClient, dest_path: str) -> None:
self.client = client
"""Local meta info for save/load etc."""
def __init__(self, handler: StorageClient, dest_path: str) -> None:
self.is_async = False
self.client = handler
self.dest_path = dest_path
self.async_upload_fn = None
def unpack_meta(meta):
args = []
is_async = meta.is_async
for k, v in meta.__dict__.items():
if k == "endpoint":
if k in ("endpoint", "async_upload_fn", "is_async"):
continue
if not is_async and k in ("local_nvme_path",):
continue
args.append(v)
return args
@ -101,21 +131,6 @@ def compute_file_md5_by_chunk(file_name: str):
return hash_md5.hexdigest()
def get_boto3_meta(fp: str) -> Boto3MetaInfo:
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
parts = fp.lstrip("s3://").split(os.path.sep)
match = boto3_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
return Boto3MetaInfo(None, bucket_name, endpoint, os.path.sep.join(parts[1:]))
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
class Boto3Client(StorageClient):
"""
Boto3Client
@ -169,7 +184,9 @@ class Boto3Client(StorageClient):
)
@staticmethod
def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
def sync_upload_fileobj(
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
): # pylint: disable=W0613
assert saved_obj is not None, "saved_obj is None!"
try:
with io.BytesIO() as f:
@ -182,7 +199,14 @@ class Boto3Client(StorageClient):
) from exc
@staticmethod
def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
def load(
handler,
bucket_name: str,
fp: str,
local_nvme_path: str, # pylint: disable=W0613
*args,
**kwargs,
) -> Dict:
"""
Args:
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@ -191,7 +215,7 @@ class Boto3Client(StorageClient):
with io.BytesIO() as f:
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
f.seek(0)
states = torch.load(f, *args, map_location=map_location, **kwargs)
states = torch.load(f, *args, **kwargs)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@ -199,15 +223,11 @@ class Boto3Client(StorageClient):
return states
@staticmethod
def assert_fp_exists(
handler,
bucket_name: str,
fp: str,
):
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
@staticmethod
def get_fns(handler, bucket_name: str, fp: str):
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
"""
Ref: https://stackoverflow.com/questions/54314563/
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
@ -222,6 +242,22 @@ class Boto3Client(StorageClient):
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
return folder_name_list
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
try:
with open(local_nvme_path, "rb") as f:
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
) from exc
except Exception as e:
raise e
@staticmethod
def delete_obj(handler, fp: str):
raise NotImplementedError("boto3 not support delete_obj")
class LocalClient(StorageClient):
"""
@ -241,11 +277,11 @@ class LocalClient(StorageClient):
torch.save(saved_obj, fp, *args, **kwargs)
@staticmethod
def load(handler, fp: str, *args, map_location="cpu", **kwargs):
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
assert isinstance(handler, LocalClient)
assert os.path.exists(fp), f"{fp} is not found!"
with open(fp, "rb") as f:
states = torch.load(f, map_location=map_location, *args, **kwargs)
states = torch.load(f, *args, **kwargs)
return states
@staticmethod
@ -267,9 +303,77 @@ class LocalClient(StorageClient):
os.remove(fp)
def get_tmp_file_name(tmp_local_folder: str, fp: str):
"""
It should be noted that all our temporary files will be stored in the same folder,
so the file name passed upstream must be unique.
"""
base_path = os.path.join(tmp_local_folder, fp.split("/")[-1])
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
pid = os.getpid()
# step = self.step_counter
return "-".join([base_path, current_time, str(pid)]) + ".tmpfile" # , str(step)
def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
parts = fp.lstrip("s3://").split(os.path.sep)
match = boto3_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
return Boto3MetaInfo(
is_async=is_async,
handler=None,
bucket_name=bucket_name,
endpoint=endpoint,
file_path=os.path.sep.join(parts[1:]),
async_upload_fn=Boto3Client.async_upload_fileobj,
local_nvme_path=tmp_step_file,
)
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
def get_mount_point_free_size(path: str):
"""
Returns the remaining free space of the temporary storage mount point, in GB.
Args:
path (str): temporary storage folder path.
Raises:
FileNotFoundError: If the temporary storage folder does not exist,
an error will be reported
"""
if os.path.exists(path):
st = os.statvfs(path)
# f_bavail: Number of free blocks for unprivileged users.
# f_bsize: Filesystem block size.
# return unit is GB.
return st.f_bavail * st.f_bsize / (1024**3)
def check_tmp_folder_accessibility(tmp_local_folder: str):
"""
Check access permissions for temporary storage.
"""
ret = True
if os.path.exists(tmp_local_folder):
ret &= os.access(tmp_local_folder, os.W_OK)
ret &= os.access(tmp_local_folder, os.R_OK)
if ret is False:
error_str = f"{socket.gethostname()} does not have read and write permissions on {tmp_local_folder}"
raise RuntimeError(error_str)
class StorageManager(metaclass=SingletonMeta):
"""
Storage Manager for saving or loading checkpoint.
TODO: add a thread to poll the asynchronous storage state.
"""
BACKEND_TYPE = {"boto3", "local"}
@ -279,9 +383,40 @@ class StorageManager(metaclass=SingletonMeta):
}
CLI_DICT = {}
def __init__(self) -> None:
def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
self._exception_list = []
self._to_be_del_files = []
self._async_stack = []
self.upload_count = 0
self.tmp_local_folder = tmp_local_folder
self.async_mode = async_mode
self.has_warning = False
if enable_save and self.async_mode:
self._async_loop = asyncio.new_event_loop()
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
# Try to create tmp folder
try:
os.makedirs(self.tmp_local_folder, exist_ok=True)
os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
except FileExistsError:
pass
# In case it is a directory created by other users, we check the permissions again.
check_tmp_folder_accessibility(self.tmp_local_folder)
# Try to clean up leftover tmp files in the tmp folder.
self.try_delete_tmpfile(self.tmp_local_folder)
# Available storage space check.
free_size = get_mount_point_free_size(self.tmp_local_folder)
if free_size < 0.1:
logger.error(f'tmp_local_folder only has "{free_size}" GB of free space, less than 100 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path: str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
"""
example:
@ -301,7 +436,7 @@ class StorageManager(metaclass=SingletonMeta):
meta_info = get_local_meta(path)
backend_key = backend
elif backend == "boto3":
meta_info = get_boto3_meta(path)
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (meta_info.endpoint,)
if (
@ -310,10 +445,12 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
raise RuntimeWarning(
if not self.has_warning:
logger.warning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."
the proxy may make boto3 unavailable or affect performance."
)
self.has_warning = True
assert backend in StorageManager.BACKEND_TYPE, f"Unknown backend: {backend}"
@ -333,19 +470,137 @@ the proxy may make boto3 unavailable or affect performance."
meta = self._get_client(path=folder)
return meta.client.get_fns(*unpack_meta(meta))
def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
meta = self._get_client(path=save_path)
if async_upload is None:
async_upload = self.async_mode
if async_upload:
assert (
self.tmp_local_folder
), "StorageManager is not setted tmp_local_folder, so async save cannot be performed."
tmp_step_file = meta.local_nvme_path
self._to_be_del_files.append(tmp_step_file)
with open(tmp_step_file, "wb") as f:
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
else:
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
self.upload_count += 1
def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
def load(self, load_path: str, *args, **kwargs) -> Any:
self.wait()
meta = self._get_client(path=load_path)
return meta.client.load(*unpack_meta(meta), map_location=map_location, *args, **kwargs)
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
def delete_obj(self, fp: str):
meta = self._get_client(path=fp)
meta.client.delete_obj(*unpack_meta(meta))
def _del_tmp_folder(self):
for fp in self._to_be_del_files:
try:
os.remove(fp)
except FileNotFoundError:
pass
except SystemError as e:
logger.error(f'delete file: {fp}, failed for reason:"{e}"')
else:
pass
storage_manager = StorageManager()
def try_delete_tmpfile(self, tmp_dir: str):
"""Delete temporary files in tmp_dir."""
for filename in os.listdir(tmp_dir):
if filename.endswith(".tmpfile"):
file_path = os.path.join(tmp_dir, filename)
try:
os.remove(file_path)
logger.info(f"Delete tmpfile: {file_path}")
except OSError:
# Ignore deletion errors
pass
async def _sync_tasks(self) -> Awaitable[None]:
if not self._async_stack:
return
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
for task in self._async_stack:
try:
task.exception()
except InvalidStateError:
continue
except Exception as e:
file_id = len(self._exception_list)
self._exception_list.append((e, file_id))
logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}")
self._async_stack.clear()
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
"""
Overview:
Execute task in background, then append the future instance to _async_stack.
Arguments:
- fn (:obj:`Callable`): Synchronous function to run in the thread pool.
"""
if not self._async_loop:
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
self._async_stack.append(t)
def wait(self) -> bool:
"""Wait for async operations to complete."""
if not self.async_mode:
return
if self._async_loop:
self._async_loop.run_until_complete(self._sync_tasks())
if self._exception_list:
for error_msg, file_id in self._exception_list:
logger.error(
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
f"failed on step {self.upload_count}: {error_msg}"
)
# TODO: Re-upload in sync mode
raise RuntimeError(
f"Failed to upload {self._to_be_del_files[file_id]} " f"on step {self.upload_count}: {error_msg}"
)
self._del_tmp_folder()
self._exception_list.clear()
self._to_be_del_files.clear()
if gpc.is_rank_for_log():
logger.info("all async uploads succeeded!")
self.upload_count += 1
storage_manager: StorageManager = None
def init_storage_manager(ckpt_config):
global storage_manager
storage_manager = StorageManager(
ckpt_config.enable_save_ckpt,
tmp_local_folder=ckpt_config.async_upload_tmp_folder,
async_mode=ckpt_config.async_upload,
)
def get_storage_manager():
assert storage_manager is not None, "storage_manager has not been initialized!"
return storage_manager
def wait_async_upload_finish():
dist.barrier()
storage_manager.wait()
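A minimal sketch of the storage-manager flow introduced above; it assumes gpc.config.ckpt has already been populated by args_sanity_check, and the paths, bucket endpoint and state_dict are placeholders:

from internlm.core.context import global_context as gpc
from internlm.utils.storage_manager import (
    get_storage_manager,
    init_storage_manager,
    llm_load,
    llm_save,
)

init_storage_manager(gpc.config.ckpt)  # done once inside args_sanity_check
storage = get_storage_manager()

# With async_upload=True and a boto3 path, the object is first serialized into
# async_upload_tmp_folder and then uploaded by a background worker thread.
llm_save("boto3:s3://model_weights.<ip>/internlm/demo/model_tp0_pp0.pt", saved_obj=state_dict)

storage.wait()  # block until all pending asynchronous uploads have finished
states = llm_load("local:llm_ckpts/49/model_tp0_pp0.pt", map_location="cpu")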

View File

@ -30,6 +30,8 @@ from internlm.data.packed_dataset import (
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
@ -37,7 +39,6 @@ from internlm.utils.common import (
BatchSkipper,
get_master_node,
get_megatron_flops,
get_process_rank,
launch_time,
parse_args,
)
@ -45,12 +46,12 @@ from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_paral
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import (
CheckpointSaveManager,
load_context,
load_model_checkpoint,
load_optimizer_checkpoint,
load_sampler,
load_scheduler,
save_checkpoint,
)
from internlm.utils.parallel import (
get_parallel_log_file_name,
@ -92,6 +93,15 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
def initialize_llm_logger(start_time: str):
"""
Initialize customized uniscale logger.
Args:
start_time (str): The launch time of current training job.
Returns: The instance of uniscale logger.
"""
uniscale_logger = initialize_uniscale_logger(
job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
)
@ -213,6 +223,8 @@ def get_train_data_loader(num_worker: int = 0):
def get_validation_data_loader(num_worker: int = 0):
"""Generate and return the validation data loader."""
data_cfg = gpc.config.data
if not data_cfg.valid_folder:
@ -327,6 +339,8 @@ def record_current_batch_training_metrics(
Print some training metrics of current batch.
"""
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
@ -405,12 +419,11 @@ def record_current_batch_training_metrics(
else:
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
def main(args):
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# init setting
skip_batches = gpc.config.data.skip_batches
total_steps = gpc.config.data.total_steps
@ -419,11 +432,6 @@ def main(args):
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
# ckpt setting
save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder
enable_save_ckpt = gpc.config.ckpt.enable_ckpt
checkpoint_every = gpc.config.ckpt.checkpoint_every
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
@ -477,8 +485,8 @@ def main(args):
model_load_path = load_model_only_folder
else:
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},"
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
@ -514,6 +522,14 @@ def main(args):
if load_optimizer:
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
ckpt_save_manager = CheckpointSaveManager(
ckpt_config=gpc.config.ckpt,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
model_config=gpc.config.model,
)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),
@ -594,6 +610,9 @@ def main(args):
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
)
# calculate and record the training metrics, eg. loss, accuracy and so on.
record_current_batch_training_metrics(
@ -629,26 +648,27 @@ def main(args):
)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# save batch sampler that tracks the true consumed samples
if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
save_checkpoint(
folder=save_ckpt_folder,
model=model,
optimizer=optimizer,
scheduler=lr_scheduler,
train_state=train_state,
model_config=gpc.config.model,
)
# # save batch sampler that tracks the true consumed samples
ckpt_save_manager.try_save_checkpoint(train_state)
# wait for all checkpoint uploads to be completed
dist.barrier()
ckpt_save_manager.wait_async_upload_finish()
if __name__ == "__main__":
args = parse_args()
hostname = socket.gethostname()
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
try:
main(args)
except Exception:
print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
traceback.print_exc()
logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
exc_info=traceback.format_exc(),
)
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())