diff --git a/doc/code-docs/source/checkpoint.rst b/doc/code-docs/source/checkpoint.rst
index cd9b755..de0c4cb 100644
--- a/doc/code-docs/source/checkpoint.rst
+++ b/doc/code-docs/source/checkpoint.rst
@@ -39,7 +39,7 @@ CheckpointManager
         load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"),
         auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'.
         checkpoint_every=CHECKPOINT_EVERY,
-        async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
+        async_upload=True, # async ckpt upload. (only works for boto3 and volc ckpt)
         async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
         oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
     )
@@ -67,7 +67,9 @@ InternLM对config中出现的所有存储路径都遵循以下的路径格式约
 
 1. 如果需要使用boto3的路径,需要在运行前提前导入 ``S3_ACCESS_KEY_ID`` 和 ``S3_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
 
-2. bucket的endpoint一般分为Inside IP和Outside IP,如果可以尽量使用inside IP,会获得更佳的存储速度。
+2. 如果需要使用volc的路径,需要在运行前提前导入 ``VOLC_ACCESS_KEY_ID`` 和 ``VOLC_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
+
+3. bucket的endpoint一般分为Inside IP和Outside IP,如果可以尽量使用inside IP,会获得更佳的存储速度。
 
 
 
@@ -114,7 +116,7 @@ config.ckpt 中相关的参数:
 
 - ``async_upload_tmp_folder``: 异步上传临时存储路径。参数类型 ``str/None``, 默认值为 ``/dev/shm/{JOB_NAME}_tmp_ckpt/``。
 
-需要注意的是,异步上传功能仅在backend为boto3时才会有效果,bcakend为local时只支持同步存储。
+需要注意的是,异步上传功能仅在backend为boto3或volc时才会有效果,backend为local时只支持同步存储。
 
 ``async_upload_tmp_folder`` 设置的的原则为尽量设置为计算节点的local目录,这样才可以获得最佳的异步上传速度,一般来说建议为 ``/dev/shm`` 或 ``/nvme`` 下的路径,如果使用同步上传,则该路径可不给。
 
diff --git a/doc/imgs/ckpt_path_format_CN.png b/doc/imgs/ckpt_path_format_CN.png
index 0307d22..0b21f54 100644
Binary files a/doc/imgs/ckpt_path_format_CN.png and b/doc/imgs/ckpt_path_format_CN.png differ
diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py
index a3f9122..89f8023 100644
--- a/internlm/utils/storage_manager.py
+++ b/internlm/utils/storage_manager.py
@@ -25,6 +25,7 @@ from internlm.utils.logger import get_logger
 try:
     import boto3
     import botocore
+    import tos
 except ImportError:
     pass
 
@@ -32,6 +33,7 @@ except ImportError:
 logger = get_logger(__file__)
 
 boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
+volc_url_re = re.compile(r"^(.*?)\.(.*)$")
 
 MB = 1024**2
 
@@ -122,6 +124,47 @@ local_nvme_path: {self.local_nvme_path}"
         return meta.client, meta.bucket_name, meta.file_path
 
 
+class VolcMetaInfo:
+    """Volc meta info for save/load etc."""
+
+    def __init__(
+        self,
+        is_async,
+        handler: StorageClient,
+        bucket_name: str,
+        endpoint: str,
+        region: str,
+        file_path: str,
+        async_upload_fn: callable,
+        local_nvme_path=None,
+    ) -> None:
+        # all need info.
+        self.client = handler
+        self.bucket_name = bucket_name
+        self.file_path = file_path
+        # only save need info.
+        self.local_nvme_path = local_nvme_path
+        self.is_async = is_async
+        self.endpoint = endpoint
+        self.region = region
+        self.async_upload_fn = async_upload_fn
+
+    def __str__(self) -> str:
+        return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
+region:{self.region}, local_nvme_path: {self.local_nvme_path}"
+
+    @staticmethod
+    def unpack_volc_save_meta(meta):
+        if meta.is_async:
+            return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
+        else:
+            return meta.client, meta.bucket_name, meta.file_path
+
+    @staticmethod
+    def unpack_volc_nosave_meta(meta):
+        return meta.client, meta.bucket_name, meta.file_path
+
+
 class LocalMetaInfo:
     """Local meta info for save/load etc."""
 
@@ -139,18 +182,22 @@ class LocalMetaInfo:
         return (meta.file_path,)
 
 
-def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
+def unpack_save_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]):
     if isinstance(meta, Boto3MetaInfo):
         return Boto3MetaInfo.unpack_boto3_save_meta(meta)
+    elif isinstance(meta, VolcMetaInfo):
+        return VolcMetaInfo.unpack_volc_save_meta(meta)
     elif isinstance(meta, LocalMetaInfo):
         return LocalMetaInfo.unpack_local_save_meta(meta)
     else:
         raise ValueError(f"unkonwn meta info: {type(meta)}")
 
 
-def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
+def unpack_nosave_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]):
     if isinstance(meta, Boto3MetaInfo):
         return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
+    elif isinstance(meta, VolcMetaInfo):
+        return VolcMetaInfo.unpack_volc_nosave_meta(meta)
     elif isinstance(meta, LocalMetaInfo):
         return LocalMetaInfo.unpack_local_nosave_meta(meta)
     else:
         raise ValueError(f"unkonwn meta info: {type(meta)}")
@@ -170,6 +217,10 @@ def try_get_storage_backend(path: str):
         if gpc.is_rank_for_log():
             logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
         return "boto3", path
+    elif path.startswith("vc:"):
+        if gpc.is_rank_for_log():
+            logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
+        return "volc", path
     else:
         sre = path.split(":", maxsplit=1)
         if len(sre) == 1:
@@ -312,6 +363,143 @@ class Boto3Client(StorageClient):
         raise NotImplementedError("boto3 not support delete_obj")
 
 
+class VolcClient(StorageClient):
+    """
+    VolcClient
+    """
+
+    def __init__(
+        self,
+        endpoint: str,
+        region: str,
+    ) -> None:
+        """Volc object/file storage management class
+
+        Args:
+            access_key (str): Volc access key ID (read from the ``VOLC_ACCESS_KEY_ID`` environment variable).
+            secret_key (str): Volc secret access key (read from the ``VOLC_SECRET_ACCESS_KEY_ID`` environment variable).
+            endpoint (str): Volc tos endpoint.
+            region (str): Volc tos region.
+
+        """
+        super().__init__(tos)
+
+        try:
+            access_key = os.environ["VOLC_ACCESS_KEY_ID"]
+            secret_key = os.environ["VOLC_SECRET_ACCESS_KEY_ID"]
+        except KeyError as exc:
+            raise RuntimeError(
+                "Please set 'VOLC_ACCESS_KEY_ID' and 'VOLC_SECRET_ACCESS_KEY_ID'",
+                "using environment variable!",
+            ) from exc
+
+        self.client = self.handler.TosClientV2(access_key, secret_key, endpoint, region)
+
+    @staticmethod
+    def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
+        assert saved_obj is not None, "saved_obj is None!"
+        try:
+            with io.BytesIO() as f:
+                torch.save(saved_obj, f, **kwargs)
+                f.seek(0)
+                handler.client.put_object(bucket_name, fp, content=f)
+        except handler.handler.exceptions.TosClientError as exc:
+            raise RuntimeError(
+                f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
+            ) from exc
+        except handler.handler.exceptions.TosServerError as exc:
+            raise RuntimeError(
+                f"Volc Network Error: fail with server error, code: {exc.code}",
+                f"error with request id: {exc.request_id}",
+                f"error with message: {exc.message}",
+                f"error with http code: {exc.status_code}",
+            ) from exc
+
+    @staticmethod
+    def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
+        """
+        Args:
+            fp (str): Path to load, eg. vc://opennlplab/model_weights/xxx/ddd.pt
+        """
+        try:
+            object_stream = handler.client.get_object(bucket_name, fp)
+            buffer = io.BytesIO(object_stream.read())
+            states = torch.load(buffer, **kwargs)
+        except handler.handler.exceptions.TosClientError as exc:
+            raise RuntimeError(
+                f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
+            ) from exc
+        except handler.handler.exceptions.TosServerError as exc:
+            raise RuntimeError(
+                f"Volc Network Error: fail with server error, code: {exc.code}",
+                f"error with request id: {exc.request_id}",
+                f"error with message: {exc.message}",
+                f"error with http code: {exc.status_code}",
+            ) from exc
+
+        return states
+
+    @staticmethod
+    def assert_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
+        assert len(list(handler.client.list_objects_type2(bucket_name, prefix=fp).contents)) > 0, fp
+
+    @staticmethod
+    def is_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
+        resp = handler.client.list_objects_type2(bucket_name, prefix=fp)
+        if hasattr(resp, "contents"):
+            return len(list(resp.contents)) > 0
+        else:
+            return False
+
+    @staticmethod
+    def get_fns(handler, bucket_name: str, fp: str):
+        if VolcClient.is_fp_exists(handler, bucket_name, fp):
+            folder_name_list = []
+            result = handler.client.list_objects_type2(bucket_name, prefix=fp)
+            if hasattr(result, "contents"):
+                for iterm in result.contents:
+                    pth = iterm.key
+                    folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
+
+            while result.is_truncated:
+                result = handler.client.list_objects_type2(
+                    bucket_name, prefix=fp, continuation_token=result.next_continuation_token
+                )
+                if hasattr(result, "contents"):
+                    for iterm in result.contents:
+                        pth = iterm.key
+                        folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
+
+            return list(set(folder_name_list))
+
+        else:
+            if gpc.is_rank_for_log():
+                logger.warning(f"'{fp}' not found!")
+            return None
+
+    @staticmethod
+    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
+        try:
+            handler.client.put_object_from_file(bucket_name, fp, local_nvme_path)
+        except handler.handler.exceptions.TosClientError as exc:
+            raise RuntimeError(
+                f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
+            ) from exc
+        except handler.handler.exceptions.TosServerError as exc:
+            raise RuntimeError(
+                f"Volc Network Error: fail with server error, code: {exc.code}",
+                f"error with request id: {exc.request_id}",
+                f"error with message: {exc.message}",
+                f"error with http code: {exc.status_code}",
+            ) from exc
+        except Exception as e:
+            raise e
+
+    @staticmethod
+    def delete_obj(handler, fp: str):
+        raise NotImplementedError("volc not support delete_obj")
+
+
 class LocalClient(StorageClient):
     """
     Storage Client for local NFS.
@@ -388,8 +576,35 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
     )
 
 
+def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInfo:
+    assert fp.startswith("vc://"), f"Path '{fp}' is not a volc url"
+    parts = fp[len("vc://"):].split(os.path.sep)
+    match = volc_url_re.match(parts[0])
+    assert match is not None, f"url '{fp}' is not a valid volc url"
+    bucket_name, endpoint = match.group(1), match.group(2)
+    temp_part = endpoint.split(".")
+    endpoint = ".".join(temp_part[1:])
+    region = temp_part[1].split("-")
+    region = "-".join(region[1:])
+
+    if is_async:
+        tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
+    else:
+        tmp_step_file = None
+    return VolcMetaInfo(
+        is_async=is_async,
+        handler=None,
+        bucket_name=bucket_name,
+        endpoint=endpoint,
+        region=region,
+        file_path=os.path.sep.join(parts[1:]),
+        async_upload_fn=VolcClient.async_upload_fileobj,
+        local_nvme_path=tmp_step_file,
+    )
+
+
 def get_local_meta(fp: str) -> LocalMetaInfo:
-    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
+    assert not fp.startswith("s3://") and not fp.startswith("vc://"), f"Path '{fp}' is not a local path"
     return LocalMetaInfo(fp)
 
 
@@ -430,10 +645,11 @@ class StorageManager(metaclass=SingletonMeta):
     TODO: add a thread to poll the asynchronous storage state.
     """
 
-    BACKEND_TYPE = {"boto3", "local"}
+    BACKEND_TYPE = {"boto3", "local", "volc"}
     BACKEND_INIT_METHOD = {
         "boto3": Boto3Client,
         "local": LocalClient,
+        "volc": VolcClient,
     }
     CLI_DICT = {}
 
@@ -476,11 +692,12 @@ class StorageManager(metaclass=SingletonMeta):
                 logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
                 raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
 
-    def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
+    def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]:
         """
         example:
         local:/path/to/checkpoint
         boto3:s3://model_weights/0331/120bi
+        volc:vc://model_weights/0331/120bi
 
         Args:
             path (str): _description_
@@ -507,10 +724,29 @@ class StorageManager(metaclass=SingletonMeta):
                         the proxy may make boto3 unavailable or affect performance."
                     )
                     self.has_warning = True
+        elif backend == "volc":
+            meta_info = get_volc_meta(path, self.tmp_local_folder, async_mode)
+            backend_key = backend + ":" + meta_info.endpoint
+            init_args = (
+                meta_info.endpoint,
+                meta_info.region,
+            )
+            if (
+                "http_proxy" in os.environ
+                or "https_proxy" in os.environ
+                or "HTTP_PROXY" in os.environ
+                or "HTTPS_PROXY" in os.environ
+            ):
+                if not self.has_warning and gpc.is_rank_for_log():
+                    logger.warning(
+                        "HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
+                        the proxy may make volc unavailable or affect performance."
+                    )
+                    self.has_warning = True
 
         assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
 
-        # boto3 backend need special treatment.
+        # boto3 and volc backends need special treatment.
         if backend_key not in StorageManager.CLI_DICT:
             StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})
 
@@ -527,11 +763,10 @@ class StorageManager(metaclass=SingletonMeta):
         return meta.client.get_fns(*unpack_nosave_meta(meta))
 
     def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
-
         if async_upload is None:
             async_upload = self.async_mode
 
-        if not save_path.startswith("boto3:"):
+        if not save_path.startswith("boto3:") and not save_path.startswith("volc:"):
             async_upload = False
 
         meta = self._get_client(save_path, async_upload)
 
@@ -554,6 +789,7 @@
     def load(self, load_path: str, **kwargs) -> Any:
         self.wait()
         meta = self._get_client(path=load_path)
+
         return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
 
     def delete_obj(self, fp: str):
diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py
index e5f60c4..949c5ef 100644
--- a/tests/test_utils/test_storage_manager.py
+++ b/tests/test_utils/test_storage_manager.py
@@ -6,9 +6,10 @@ import torch
 from internlm.core.context.parallel_context import Config
 from internlm.initialize.launch import get_config_value
 from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
     ASYNC_TMP_FOLDER,
     BOTO_SAVE_PATH,
     LOCAL_SAVE_PATH,
+    VOLC_SAVE_PATH,
     del_tmp_file,
     init_dist_and_model,
     reset_singletons,
@@ -48,6 +48,22 @@ ckpt_config_list = [
         save_folder=LOCAL_SAVE_PATH,
         test_id=3,
     ),
+    # async volc
+    dict(
+        enable_save_ckpt=True,
+        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
+        async_upload=True,
+        save_folder=VOLC_SAVE_PATH,
+        test_id=4,
+    ),
+    # sync volc
+    dict(
+        enable_save_ckpt=True,
+        async_upload_tmp_folder=None,
+        async_upload=False,
+        save_folder=VOLC_SAVE_PATH,
+        test_id=5,
+    ),
 ]
 
 
@@ -97,6 +113,9 @@ internlm_ckpt_path = [
     ("/mnt/ckpt/", "local", "/mnt/ckpt/"),
     ("./ckpt/", "local", "./ckpt/"),
     ("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
+    ("volc:vc://oss_bucket/", "volc", "vc://oss_bucket/"),
+    ("volc:oss_bucket/", "volc", "oss_bucket/"),
+    ("vc://oss_bucket/", "volc", "vc://oss_bucket/"),
 ]
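
For reference, a ckpt configuration that targets the new volc (TOS) backend would look roughly like the sketch below. It mirrors the boto3-style example quoted from checkpoint.rst above; the bucket name, endpoint, folder, and CHECKPOINT_EVERY value are placeholders rather than a real deployment, and the credentials are expected in the VOLC_ACCESS_KEY_ID / VOLC_SECRET_ACCESS_KEY_ID environment variables as described in the documentation change.

    # Minimal sketch -- the vc:// bucket/endpoint below are placeholders.
    CHECKPOINT_EVERY = 50  # placeholder value
    ckpt = dict(
        load_ckpt_folder=dict(path="volc:vc://my_bucket.my_tos_endpoint/ckpt", content=["all",], ckpt_type="internlm"),
        auto_resume=False,  # load from 'load_ckpt_folder' instead of auto-resuming.
        checkpoint_every=CHECKPOINT_EVERY,
        async_upload=True,  # async upload is effective for the boto3 and volc backends.
        async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # local temp dir used during async upload.
        oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
    )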
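
The path-prefix resolution exercised by the new test cases can also be read directly from try_get_storage_backend: a "volc:"-prefixed path is split on the first colon, while a bare "vc://" path is guessed to belong to the volc backend. A small sketch of the expected mapping, using the tuples from tests/test_utils/test_storage_manager.py (it assumes the usual test initialization, since the function may log a warning through gpc):

    from internlm.utils.storage_manager import try_get_storage_backend

    # (input path, expected backend, expected resolved path), taken from the new test cases.
    assert try_get_storage_backend("volc:vc://oss_bucket/") == ("volc", "vc://oss_bucket/")
    assert try_get_storage_backend("volc:oss_bucket/") == ("volc", "oss_bucket/")
    assert try_get_storage_backend("vc://oss_bucket/") == ("volc", "vc://oss_bucket/")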