diff --git a/doc/code-docs/source/checkpoint.rst b/doc/code-docs/source/checkpoint.rst index de0c4cb..a560179 100644 --- a/doc/code-docs/source/checkpoint.rst +++ b/doc/code-docs/source/checkpoint.rst @@ -39,7 +39,7 @@ CheckpointManager load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"), auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'. checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 and volc ckpt) + async_upload=True, # async ckpt upload. (only work for boto3, volc and oss2 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) @@ -69,7 +69,9 @@ InternLM对config中出现的所有存储路径都遵循以下的路径格式约 2. 如果需要使用volc的路径,需要在运行前提前导入 ``VOLC_ACCESS_KEY_ID`` 和 ``VOLC_SECRET_ACCESS_KEY_ID`` 这两个环境变量。 -3. bucket的endpoint一般分为Inside IP和Outside IP,如果可以尽量使用inside IP,会获得更佳的存储速度。 +3. 如果需要使用oss2的路径,需要在运行前提前导入 ``ALI_ACCESS_KEY_ID`` 和 ``ALI_SECRET_ACCESS_KEY_ID`` 这两个环境变量。 + +4. 
bucket的endpoint一般分为Inside IP和Outside IP,如果可以尽量使用inside IP,会获得更佳的存储速度。 @@ -116,7 +118,7 @@ config.ckpt 中相关的参数: - ``async_upload_tmp_folder``: 异步上传临时存储路径。参数类型 ``str/None``, 默认值为 ``/dev/shm/{JOB_NAME}_tmp_ckpt/``。 -需要注意的是,异步上传功能仅在backend为boto3或volc时才会有效果,bcakend为local时只支持同步存储。 +需要注意的是,异步上传功能仅在backend为非local时才会有效果,backend为local时只支持同步存储。 ``async_upload_tmp_folder`` 设置的的原则为尽量设置为计算节点的local目录,这样才可以获得最佳的异步上传速度,一般来说建议为 ``/dev/shm`` 或 ``/nvme`` 下的路径,如果使用同步上传,则该路径可不给。 diff --git a/doc/imgs/ckpt_path_format_CN.png b/doc/imgs/ckpt_path_format_CN.png index 0b21f54..9649dc4 100644 Binary files a/doc/imgs/ckpt_path_format_CN.png and b/doc/imgs/ckpt_path_format_CN.png differ diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 89f8023..21e5ef4 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -25,15 +25,25 @@ from internlm.utils.logger import get_logger try: import boto3 import botocore +except ImportError: + pass + +try: import tos except ImportError: pass +try: + import oss2 +except ImportError: + pass + logger = get_logger(__file__) boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)") volc_url_re = re.compile(r"^(.*?)\.(.*)$") +ali_url_re = re.compile(r"([^/.]+)\.([^/.]+\..+)") MB = 1024**2 @@ -165,6 +175,45 @@ region:{self.region}, local_nvme_path: {self.local_nvme_path}" return meta.client, meta.bucket_name, meta.file_path +class AliMetaInfo: + """Ali meta info for save/load etc.""" + + def __init__( + self, + is_async, + handler: StorageClient, + bucket_name: str, + endpoint: str, + file_path: str, + async_upload_fn: callable, + local_nvme_path=None, + ) -> None: + # all need info. + self.client = handler + self.bucket_name = bucket_name + self.file_path = file_path + # only save need info. 
+ self.local_nvme_path = local_nvme_path + self.is_async = is_async + self.endpoint = endpoint + self.async_upload_fn = async_upload_fn + + def __str__(self) -> str: + return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \ +local_nvme_path: {self.local_nvme_path}" + + @staticmethod + def unpack_ali_save_meta(meta): + if meta.is_async: + return meta.client, meta.file_path, meta.local_nvme_path + else: + return meta.client, meta.file_path + + @staticmethod + def unpack_ali_nosave_meta(meta): + return meta.client, meta.file_path + + class LocalMetaInfo: """Local meta info for save/load etc.""" @@ -182,22 +231,26 @@ class LocalMetaInfo: return (meta.file_path,) -def unpack_save_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]): +def unpack_save_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, AliMetaInfo, LocalMetaInfo]): if isinstance(meta, Boto3MetaInfo): return Boto3MetaInfo.unpack_boto3_save_meta(meta) elif isinstance(meta, VolcMetaInfo): return VolcMetaInfo.unpack_volc_save_meta(meta) + elif isinstance(meta, AliMetaInfo): + return AliMetaInfo.unpack_ali_save_meta(meta) elif isinstance(meta, LocalMetaInfo): return LocalMetaInfo.unpack_local_save_meta(meta) else: raise ValueError(f"unkonwn meta info: {type(meta)}") -def unpack_nosave_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]): +def unpack_nosave_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, AliMetaInfo, LocalMetaInfo]): if isinstance(meta, Boto3MetaInfo): return Boto3MetaInfo.unpack_boto3_nosave_meta(meta) elif isinstance(meta, VolcMetaInfo): return VolcMetaInfo.unpack_volc_nosave_meta(meta) + elif isinstance(meta, AliMetaInfo): + return AliMetaInfo.unpack_ali_nosave_meta(meta) elif isinstance(meta, LocalMetaInfo): return LocalMetaInfo.unpack_local_nosave_meta(meta) else: @@ -221,6 +274,10 @@ def try_get_storage_backend(path: str): if gpc.is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of 
volc.") return "volc", path + elif path.startswith("ali:"): + if gpc.is_rank_for_log(): + logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.") + return "oss2", path else: sre = path.split(":", maxsplit=1) if len(sre) == 1: @@ -500,6 +557,99 @@ class VolcClient(StorageClient): raise NotImplementedError("volc not support delete_obj") +class AliClient(StorageClient): + """ + AliClient + """ + + def __init__( + self, + bucket_name: str, + endpoint: str, + ) -> None: + """Ali object/file storage management class + + Args: + access_key (str): Ali access key ID. + secret_key (str): Ali secret access key. + endpoint (str): Ali tos endpoint. + region (str): Ali tos region. + + """ + super().__init__(oss2) + + try: + access_key = os.environ["ALI_ACCESS_KEY_ID"] + secret_key = os.environ["ALI_SECRET_ACCESS_KEY_ID"] + except KeyError as exc: + raise RuntimeError( + "Please set 'ALI_ACCESS_KEY_ID' and 'ALI_SECRET_ACCESS_KEY_ID'", + "using environment variable!", + ) from exc + + self.auth = self.handler.Auth(access_key, secret_key) + self.client = self.handler.Bucket(self.auth, endpoint, bucket_name) + + @staticmethod + def sync_upload_fileobj(handler, fp: str, saved_obj=None, **kwargs): + assert saved_obj is not None, "saved_obj is None!" + try: + with io.BytesIO() as f: + torch.save(saved_obj, f, **kwargs) + f.seek(0) + handler.client.put_object(fp, f) + except Exception as e: + raise e + + @staticmethod + def load(handler, fp: str, **kwargs) -> Dict: + """ + Args: + fp (str): Path to save, eg. 
ali://opennlplab/model_weights/xxx/ddd.pt + """ + try: + object_stream = handler.client.get_object(fp) + buffer = io.BytesIO(object_stream.read()) + states = torch.load(buffer, **kwargs) + except Exception as e: + raise e + + return states + + @staticmethod + def assert_fp_exists(handler, fp: str): # pylint: disable=W0613 + assert len(list(handler.handler.ObjectIteratorV2(handler.client, prefix=fp))) > 0, fp + + @staticmethod + def is_fp_exists(handler, fp: str): # pylint: disable=W0613 + return len(list(handler.handler.ObjectIteratorV2(handler.client, prefix=fp))) > 0 + + @staticmethod + def get_fns(handler, fp: str): + if AliClient.is_fp_exists(handler, fp): + folder_name_list = [] + for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp): + folder_name_list.append(obj.key.split("/")[-1]) + + return list(set(folder_name_list)) + + else: + if gpc.is_rank_for_log(): + logger.warning(f"'{fp}' not found!") + return None + + @staticmethod + def async_upload_fileobj(handler, fp: str, local_nvme_path: str): + try: + handler.client.put_object_from_file(fp, local_nvme_path) + except Exception as e: + raise e + + @staticmethod + def delete_obj(handler, fp: str): + raise NotImplementedError("ali not support delete_obj") + + class LocalClient(StorageClient): """ Storage Client for local NFS. 
@@ -582,9 +732,8 @@ def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInf match = volc_url_re.match(parts[0]) assert match is not None, f"url '{fp}' is not a valid volc url" bucket_name, endpoint = match.group(1), match.group(2) - temp_part = endpoint.split(".") - endpoint = ".".join(temp_part[1:]) - region = temp_part[1].split("-") + region = endpoint.split(".") + region = region[0].split("-") region = "-".join(region[1:]) if is_async: @@ -603,8 +752,32 @@ def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInf ) +def get_ali_meta(fp: str, tmp_local_folder: str, is_async: bool) -> AliMetaInfo: + assert fp.startswith("ali://"), f"Path '{fp}' is not a ali url" + parts = fp.lstrip("ali://").split(os.path.sep) + match = ali_url_re.match(parts[0]) + assert match is not None, f"url '{fp}' is not a valid ali url" + bucket_name, endpoint = match.group(1), match.group(2) + + if is_async: + tmp_step_file = get_tmp_file_name(tmp_local_folder, fp) + else: + tmp_step_file = None + return AliMetaInfo( + is_async=is_async, + handler=None, + bucket_name=bucket_name, + endpoint=endpoint, + file_path=os.path.sep.join(parts[1:]), + async_upload_fn=AliClient.async_upload_fileobj, + local_nvme_path=tmp_step_file, + ) + + def get_local_meta(fp: str) -> LocalMetaInfo: - assert not fp.startswith("s3://") and not fp.startswith("vc://"), f"Path '{fp}' is not a local path" + assert ( + not fp.startswith("s3://") and not fp.startswith("vc://") and not fp.startswith("ali://") + ), f"Path '{fp}' is not a local path" return LocalMetaInfo(fp) @@ -645,11 +818,12 @@ class StorageManager(metaclass=SingletonMeta): TODO: add a thread to poll the asynchronous storage state. 
""" - BACKEND_TYPE = {"boto3", "local", "volc"} + BACKEND_TYPE = {"boto3", "local", "volc", "oss2"} BACKEND_INIT_METHOD = { "boto3": Boto3Client, "local": LocalClient, "volc": VolcClient, + "oss2": AliClient, } CLI_DICT = {} @@ -692,12 +866,15 @@ class StorageManager(metaclass=SingletonMeta): logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!') raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}") - def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]: + def _get_client( + self, path: str, async_mode: bool = False + ) -> Union[Boto3MetaInfo, VolcMetaInfo, AliMetaInfo, LocalMetaInfo]: """ example: local:/path/to/checkpoint boto3:s3://model_weights/0331/120bi volc:vc://model_weights/0331/120bi + oss2:ali://model_weights/0331/120bi Args: path (str): _description_ @@ -743,10 +920,29 @@ class StorageManager(metaclass=SingletonMeta): the proxy may make volc unavailable or affect performance." ) self.has_warning = True + elif backend == "oss2": + meta_info = get_ali_meta(path, self.tmp_local_folder, async_mode) + backend_key = backend + ":" + meta_info.endpoint + init_args = ( + meta_info.bucket_name, + meta_info.endpoint, + ) + if ( + "http_proxy" in os.environ + or "https_proxy" in os.environ + or "HTTP_PROXY" in os.environ + or "HTTPS_PROXY" in os.environ + ): + if not self.has_warning and gpc.is_rank_for_log(): + logger.warning( + "HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \ + the proxy may make oss2 unavailable or affect performance." + ) + self.has_warning = True assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}" - # boto3 and volc backend need special treatment. + # boto3, volc and oss2 backend need special treatment. 
if backend_key not in StorageManager.CLI_DICT: StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)}) @@ -766,7 +962,11 @@ class StorageManager(metaclass=SingletonMeta): if async_upload is None: async_upload = self.async_mode - if not save_path.startswith("boto3:") and not save_path.startswith("volc:"): + if ( + not save_path.startswith("boto3:") + and not save_path.startswith("volc:") + and not save_path.startswith("oss2:") + ): async_upload = False meta = self._get_client(save_path, async_upload) diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 5d6d7da..56e7b21 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -22,6 +22,9 @@ BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" VOLC_SAVE_PATH = f"volc:vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" VOLC_SAVE_PATH_NO_PRFIX = f"vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" +ALI_SAVE_PATH = f"oss2:ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" +ALI_SAVE_PATH_NO_PRFIX = f"ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + ASYNC_TMP_FOLDER = "./async_tmp_folder" @@ -197,3 +200,14 @@ def del_tmp_file(): print(presults, flush=True) except: # noqa # pylint: disable=bare-except pass + + try: + cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + ALI_SAVE_PATH_NO_PRFIX + " / " + with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output: + results, presults = "", "" + for line in iter(output.stdout.readline, b""): + results += str(line.rstrip()) + presults += line.rstrip().decode() + "\n" + print(presults, flush=True) + except: # noqa # pylint: disable=bare-except + pass diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py index e96374e..e102ca1 100644 --- a/tests/test_utils/test_storage_manager.py +++ b/tests/test_utils/test_storage_manager.py @@ -6,6 +6,7 @@ import torch from internlm.core.context.parallel_context 
import Config from internlm.initialize.launch import get_config_value from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import + ALI_SAVE_PATH, BOTO_SAVE_PATH, LOCAL_SAVE_PATH, VOLC_SAVE_PATH, @@ -63,6 +64,22 @@ ckpt_config_list = [ save_folder=VOLC_SAVE_PATH, test_id=6, ), + # async ali + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + async_upload=True, + save_folder=ALI_SAVE_PATH, + test_id=7, + ), + # sync ali + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=None, + async_upload=False, + save_folder=ALI_SAVE_PATH, + test_id=8, + ), ] @@ -89,7 +106,7 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg ckpt_config = Config(ckpt_config) if os.environ.get("OSS_BUCKET_NAME") is None: if ckpt_config.test_id > 2: - print("Pass boto3 and volc", flush=True) + print("Pass boto3, volc and ali", flush=True) return enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False) @@ -120,6 +137,9 @@ internlm_ckpt_path = [ ("volc:vc://oss_bucket/", "volc", "vc://oss_bucket/"), ("volc:oss_bucket/", "volc", "oss_bucket/"), ("vc://oss_bucket/", "volc", "vc://oss_bucket/"), + ("oss2:ali://oss_bucket/", "oss2", "ali://oss_bucket/"), + ("oss2:oss_bucket/", "oss2", "oss_bucket/"), + ("ali://oss_bucket/", "oss2", "ali://oss_bucket/"), ]