pull/520/head
lijiaxing 2023-12-01 11:10:04 +08:00
parent 06cdcc3654
commit 3b7fb97e04
4 changed files with 79 additions and 40 deletions

View File

@ -720,9 +720,10 @@ class CheckpointManager:
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
torch.distributed.barrier()
# Test that the storage setting is OK.
# if self.enable_save_ckpt:
# self.try_ping_storage()
if self.enable_save_ckpt:
self.try_ping_storage()
def quit_signal_handler(self, train_state) -> bool:
"""
@ -1016,7 +1017,7 @@ now step_count is {train_state.step_count}",
self.storage_manager.latest_save_step = step
def try_ping_storage(self):
if gpc.get_global_rank() % 8 == 0:
if gpc.is_rank_for_log():
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
self.storage_manager.save(test_fn, buff)
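Note: the re-enabled probe just writes a tiny tensor to a per-host ping file under the checkpoint folder, so a misconfigured storage target fails fast at startup instead of at the first save. A standalone sketch of the same idea for a local backend (folder name illustrative):

import os
import socket

import torch


def probe_storage(save_ckpt_folder: str) -> None:
    # Write a small tensor to a per-host ping file; any permission or
    # connectivity problem surfaces here rather than at the first checkpoint.
    buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
    test_fn = os.path.join(save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
    os.makedirs(os.path.dirname(test_fn), exist_ok=True)
    torch.save(buff, test_fn)


probe_storage("/tmp/ckpt_probe")  # raises early if the target is not writable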

View File

@ -18,10 +18,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Union
import torch
import torch.distributed as dist
from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
try:
import boto3
import botocore
@ -30,8 +26,7 @@ except ImportError:
try:
import tos
from tos import DataTransferType
from tos.utils import SizeAdapter, MergeProcess
from tos.utils import SizeAdapter
except ImportError:
pass
@ -43,7 +38,26 @@ except ImportError:
pass
logger = get_logger(__file__)
class Logger:
    """Dummy logger that prints to stdout, used when internlm is not importable."""

    def info(self, message: str):
        print(f"Info: {message}", flush=True)

    def warning(self, message: str):
        print(f"Warning: {message}", flush=True)

    def error(self, message: str):
        print(f"Error: {message}", flush=True)
try:
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
except ImportError:
logger = Logger()
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
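The try/except above is what lets this module be imported outside the InternLM tree: when internlm.utils.logger is unavailable, the stdout shim takes its place and call sites stay unchanged. A minimal sketch of the same fallback pattern (the project module name here is illustrative):

# Fallback-import sketch: prefer the project logger, degrade to the shim.
try:
    from some_project.logger import get_logger  # hypothetical project logger
    logger = get_logger(__file__)
except ImportError:
    logger = Logger()  # the stdout shim defined above

logger.info("storage module loaded")  # works in both environments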
@ -68,7 +82,13 @@ def llm_load(fp: str, **kwargs):
def llm_save(save_path: str, saved_obj: Any, **kwargs):
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
def is_rank_for_log():
if dist.is_initialized():
return dist.get_rank() % 8 == 0
return True
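With torch.distributed uninitialized, is_rank_for_log() returns True, so single-process tools still log; under a distributed launch it approximates one logging rank per 8-GPU node. A usage sketch (assuming the function and logger above):

# In a plain script (no torch.distributed init) this always prints.
# Under e.g. torchrun with 16 ranks it prints on ranks 0 and 8 only,
# i.e. roughly once per 8-GPU node.
if is_rank_for_log():
    logger.warning("path has no backend prefix, guessing local")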
class StorageClient:
"""
@ -271,21 +291,21 @@ def compute_file_md5_by_chunk(file_name: str):
def try_get_storage_backend(path: str):
if path.startswith("s3:"):
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
return "boto3", path
elif path.startswith("vc:"):
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
return "volc", path
elif path.startswith("ali:"):
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.")
return "oss2", path
else:
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
return "local", sre[0]
else:
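Taken together, the dispatch maps a recognized path prefix to a backend name and returns the path unchanged; a plain path with no prefix falls back to the local backend. Illustrative results, matching the branches shown above:

assert try_get_storage_backend("s3:bucket.10.0.0.1/ckpt") == ("boto3", "s3:bucket.10.0.0.1/ckpt")
assert try_get_storage_backend("vc:bucket/ckpt") == ("volc", "vc:bucket/ckpt")
assert try_get_storage_backend("ali:bucket/ckpt") == ("oss2", "ali:bucket/ckpt")
assert try_get_storage_backend("/mnt/ckpt") == ("local", "/mnt/ckpt")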
@ -403,7 +423,7 @@ class Boto3Client(StorageClient):
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
else:
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"'{fp}' not found!")
return None
@ -534,7 +554,7 @@ class VolcClient(StorageClient):
return list(set(folder_name_list))
else:
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"'{fp}' not found!")
return None
@ -550,22 +570,25 @@ class VolcClient(StorageClient):
parts = []
# Upload the file part by part
logger.info('Begin multipart upload of one file')
with open(local_nvme_path, 'rb') as f:
with open(local_nvme_path, "rb") as f:
part_number = 1
offset = 0
while offset < total_size:
num_to_upload = min(part_size, total_size - offset)
out = handler.client.upload_part(bucket_name, fp, upload_id, part_number,
content=SizeAdapter(f, num_to_upload, init_offset=offset))
out = handler.client.upload_part(
bucket_name,
fp,
upload_id,
part_number,
content=SizeAdapter(f, num_to_upload, init_offset=offset),
)
parts.append(out)
offset += num_to_upload
part_number += 1
# Complete the multipart upload
handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts)
logger.info('Finish multipart upload of one file')
except handler.handler.exceptions.TosClientError as exc:
raise RuntimeError(
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
@ -600,10 +623,10 @@ class AliClient(StorageClient):
"""Ali object/file storage management class
Args:
access_key (str): Ali access key ID.
access_key (str): Ali access key ID.
secret_key (str): Ali secret access key.
endpoint (str): Ali tos endpoint.
region (str): Ali tos region.
bucket_name (str): Ali tos bucket_name.
"""
super().__init__(oss2)
@ -664,7 +687,7 @@ class AliClient(StorageClient):
return list(set(folder_name_list))
else:
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"'{fp}' not found!")
return None
@ -672,24 +695,25 @@ class AliClient(StorageClient):
def async_upload_fileobj(handler, fp: str, local_nvme_path: str):
try:
# handler.client.put_object_from_file(fp, local_nvme_path)
total_size = os.path.getsize(local_nvme_path)
part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024)
upload_id = handler.client.init_multipart_upload(fp).upload_id
parts = []
with open(local_nvme_path, 'rb') as fileobj:
with open(local_nvme_path, "rb") as fileobj:
part_number = 1
offset = 0
while offset < total_size:
num_to_upload = min(part_size, total_size - offset)
# Calling SizedFileAdapter(fileobj, size) creates a new file-like object and recomputes the starting offset for this part.
result = handler.client.upload_part(fp, upload_id, part_number,
SizedFileAdapter(fileobj, num_to_upload))
result = handler.client.upload_part(
fp, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload)
)
parts.append(PartInfo(part_number, result.etag))
offset += num_to_upload
part_number += 1
headers = dict()
handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers)
except Exception as e:
@ -733,7 +757,7 @@ class LocalClient(StorageClient):
@staticmethod
def get_fns(folder):
if not os.path.exists(folder):
if gpc.is_rank_for_log():
if is_rank_for_log():
logger.warning(f"'{folder}' not found!")
return None
else:
@ -865,6 +889,23 @@ def check_tmp_folder_accessibility(tmp_local_folder: str):
raise RuntimeError(error_str)
class SingletonMeta(type):
"""
Singleton Meta.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
else:
assert (
len(args) == 0 and len(kwargs) == 0
), f"{cls.__name__} is a singleton class and a instance has been created."
return cls._instances[cls]
class StorageManager(metaclass=SingletonMeta):
"""
Storage Manager for saving or loading checkpoint.
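SingletonMeta was moved into this module so storage_manager no longer depends on internlm.utils.common (the test import below changes accordingly). Its contract, sketched with a throwaway class:

class _Cfg(metaclass=SingletonMeta):
    def __init__(self, name: str = "default"):
        self.name = name


a = _Cfg("first")
b = _Cfg()            # no new instance; __init__ is not run again
assert a is b and b.name == "first"
# _Cfg("second")      # would trip the assert: args after the first construction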
@ -948,7 +989,7 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning and gpc.is_rank_for_log():
if not self.has_warning and is_rank_for_log():
logger.warning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."
@ -967,7 +1008,7 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning and gpc.is_rank_for_log():
if not self.has_warning and is_rank_for_log():
logger.warning(
"HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
the proxy may make volc unavailable or affect performance."
@ -986,7 +1027,7 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning and gpc.is_rank_for_log():
if not self.has_warning and is_rank_for_log():
logger.warning(
"HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \
the proxy may make oss2 unavailable or affect performance."
@ -1132,7 +1173,7 @@ class StorageManager(metaclass=SingletonMeta):
self._to_be_del_files.clear()
self.async_task_peeding = False
if gpc.is_rank_for_log():
if is_rank_for_log():
self.upload_count += 1
if self.async_mode and self.latest_save_folder:
self.save(

View File

@ -8,7 +8,7 @@ import torch
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
from internlm.utils.storage_manager import SingletonMeta
OSS_NAME = os.environ.get("OSS_BUCKET_NAME")
OSS_IP = os.environ.get("OSS_IP")

View File

@ -100,7 +100,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
init_storage_manager,
llm_load,
llm_save,
wait_async_upload_finish,
)
ckpt_config = Config(ckpt_config)
@ -118,8 +117,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
tobj = torch.rand(64, 64)
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
llm_save(save_fn, tobj)
if ckpt_config.test_id == 0:
wait_async_upload_finish()
check_folder(save_fn)
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
load_obj = llm_load(save_fn, map_location="cpu")