mirror of https://github.com/InternLM/InternLM
feat(storage): support volc oss ckpt saving (#397)
* feat: support volc tos * feat: support volc osspull/418/head
parent
9a731b6e9b
commit
71a0388b87
|
@ -39,7 +39,7 @@ CheckpointManager
|
||||||
load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"),
|
load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"),
|
||||||
auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'.
|
auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'.
|
||||||
checkpoint_every=CHECKPOINT_EVERY,
|
checkpoint_every=CHECKPOINT_EVERY,
|
||||||
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
|
async_upload=True, # async ckpt upload. (only work for boto3 and volc ckpt)
|
||||||
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
|
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
|
||||||
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||||
)
|
)
|
||||||
|
@ -67,7 +67,9 @@ InternLM对config中出现的所有存储路径都遵循以下的路径格式约
|
||||||
|
|
||||||
1. 如果需要使用boto3的路径,需要在运行前提前导入 ``S3_ACCESS_KEY_ID`` 和 ``S3_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
|
1. 如果需要使用boto3的路径,需要在运行前提前导入 ``S3_ACCESS_KEY_ID`` 和 ``S3_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
|
||||||
|
|
||||||
2. bucket的endpoint一般分为Inside IP和Outside IP,如果可以尽量使用inside IP,会获得更佳的存储速度。
|
2. 如果需要使用volc的路径,需要在运行前提前导入 ``VOLC_ACCESS_KEY_ID`` 和 ``VOLC_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
|
||||||
|
|
||||||
|
3. bucket的endpoint一般分为Inside IP和Outside IP,如果可以尽量使用inside IP,会获得更佳的存储速度。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,7 +116,7 @@ config.ckpt 中相关的参数:
|
||||||
|
|
||||||
- ``async_upload_tmp_folder``: 异步上传临时存储路径。参数类型 ``str/None``, 默认值为 ``/dev/shm/{JOB_NAME}_tmp_ckpt/``。
|
- ``async_upload_tmp_folder``: 异步上传临时存储路径。参数类型 ``str/None``, 默认值为 ``/dev/shm/{JOB_NAME}_tmp_ckpt/``。
|
||||||
|
|
||||||
需要注意的是,异步上传功能仅在backend为boto3时才会有效果,backend为local时只支持同步存储。
|
需要注意的是,异步上传功能仅在backend为boto3或volc时才会有效果,backend为local时只支持同步存储。
|
||||||
|
|
||||||
``async_upload_tmp_folder`` 设置的原则为尽量设置为计算节点的local目录,这样才可以获得最佳的异步上传速度,一般来说建议为 ``/dev/shm`` 或 ``/nvme`` 下的路径,如果使用同步上传,则该路径可不给。
|
``async_upload_tmp_folder`` 设置的原则为尽量设置为计算节点的local目录,这样才可以获得最佳的异步上传速度,一般来说建议为 ``/dev/shm`` 或 ``/nvme`` 下的路径,如果使用同步上传,则该路径可不给。
|
||||||
|
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 153 KiB After Width: | Height: | Size: 212 KiB |
|
@ -25,6 +25,7 @@ from internlm.utils.logger import get_logger
|
||||||
try:
|
try:
|
||||||
import boto3
|
import boto3
|
||||||
import botocore
|
import botocore
|
||||||
|
import tos
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -32,6 +33,7 @@ except ImportError:
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
|
||||||
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
||||||
|
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
|
||||||
|
|
||||||
MB = 1024**2
|
MB = 1024**2
|
||||||
|
|
||||||
|
@ -122,6 +124,47 @@ local_nvme_path: {self.local_nvme_path}"
|
||||||
return meta.client, meta.bucket_name, meta.file_path
|
return meta.client, meta.bucket_name, meta.file_path
|
||||||
|
|
||||||
|
|
||||||
|
class VolcMetaInfo:
    """Container describing one Volc object plus what is needed to save or load it."""

    def __init__(
        self,
        is_async,
        handler: StorageClient,
        bucket_name: str,
        endpoint: str,
        region: str,
        file_path: str,
        async_upload_fn: callable,
        local_nvme_path=None,
    ) -> None:
        # Fields that only matter when saving.
        self.async_upload_fn = async_upload_fn
        self.region = region
        self.endpoint = endpoint
        self.is_async = is_async
        self.local_nvme_path = local_nvme_path
        # Fields every operation (save, load, listing, ...) relies on.
        self.file_path = file_path
        self.bucket_name = bucket_name
        self.client = handler

    def __str__(self) -> str:
        # Adjacent f-strings are concatenated into one single-line description.
        description = (
            f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, "
            f"region:{self.region}, local_nvme_path: {self.local_nvme_path}"
        )
        return description

    @staticmethod
    def unpack_volc_save_meta(meta):
        """Return the positional arguments for a save call.

        An async save additionally carries ``local_nvme_path`` as a fourth element.
        """
        base_args = (meta.client, meta.bucket_name, meta.file_path)
        if not meta.is_async:
            return base_args
        return base_args + (meta.local_nvme_path,)

    @staticmethod
    def unpack_volc_nosave_meta(meta):
        """Return ``(client, bucket_name, file_path)`` for non-save operations."""
        return meta.client, meta.bucket_name, meta.file_path
|
||||||
|
|
||||||
|
|
||||||
class LocalMetaInfo:
|
class LocalMetaInfo:
|
||||||
"""Local meta info for save/load etc."""
|
"""Local meta info for save/load etc."""
|
||||||
|
|
||||||
|
@ -139,18 +182,22 @@ class LocalMetaInfo:
|
||||||
return (meta.file_path,)
|
return (meta.file_path,)
|
||||||
|
|
||||||
|
|
||||||
def unpack_save_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]):
    """Unpack ``meta`` into save-call arguments using its backend-specific unpacker.

    Raises:
        ValueError: if ``meta`` is not one of the supported meta-info types.
    """
    if isinstance(meta, Boto3MetaInfo):
        unpack = Boto3MetaInfo.unpack_boto3_save_meta
    elif isinstance(meta, VolcMetaInfo):
        unpack = VolcMetaInfo.unpack_volc_save_meta
    elif isinstance(meta, LocalMetaInfo):
        unpack = LocalMetaInfo.unpack_local_save_meta
    else:
        raise ValueError(f"unkonwn meta info: {type(meta)}")
    return unpack(meta)
|
||||||
|
|
||||||
|
|
||||||
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
|
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]):
|
||||||
if isinstance(meta, Boto3MetaInfo):
|
if isinstance(meta, Boto3MetaInfo):
|
||||||
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
|
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
|
||||||
|
elif isinstance(meta, VolcMetaInfo):
|
||||||
|
return VolcMetaInfo.unpack_volc_nosave_meta(meta)
|
||||||
elif isinstance(meta, LocalMetaInfo):
|
elif isinstance(meta, LocalMetaInfo):
|
||||||
return LocalMetaInfo.unpack_local_nosave_meta(meta)
|
return LocalMetaInfo.unpack_local_nosave_meta(meta)
|
||||||
else:
|
else:
|
||||||
|
@ -170,6 +217,10 @@ def try_get_storage_backend(path: str):
|
||||||
if gpc.is_rank_for_log():
|
if gpc.is_rank_for_log():
|
||||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
||||||
return "boto3", path
|
return "boto3", path
|
||||||
|
elif path.startswith("vc:"):
|
||||||
|
if gpc.is_rank_for_log():
|
||||||
|
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
|
||||||
|
return "volc", path
|
||||||
else:
|
else:
|
||||||
sre = path.split(":", maxsplit=1)
|
sre = path.split(":", maxsplit=1)
|
||||||
if len(sre) == 1:
|
if len(sre) == 1:
|
||||||
|
@ -312,6 +363,143 @@ class Boto3Client(StorageClient):
|
||||||
raise NotImplementedError("boto3 not support delete_obj")
|
raise NotImplementedError("boto3 not support delete_obj")
|
||||||
|
|
||||||
|
|
||||||
|
class VolcClient(StorageClient):
    """Storage client for Volc object storage, built on the ``tos`` SDK.

    All operations are static methods that receive the client instance as
    ``handler``: ``handler.client`` is the ``TosClientV2`` object and
    ``handler.handler`` is the ``tos`` module (used to reach its exception types).
    """

    def __init__(
        self,
        endpoint: str,
        region: str,
    ) -> None:
        """Create a TOS client for the given endpoint and region.

        Credentials are not passed as arguments; they are read from the
        ``VOLC_ACCESS_KEY_ID`` and ``VOLC_SECRET_ACCESS_KEY_ID`` environment variables.

        Args:
            endpoint (str): Volc tos endpoint.
            region (str): Volc tos region.

        Raises:
            RuntimeError: if either credential environment variable is not set.
        """
        super().__init__(tos)

        try:
            access_key = os.environ["VOLC_ACCESS_KEY_ID"]
            secret_key = os.environ["VOLC_SECRET_ACCESS_KEY_ID"]
        except KeyError as exc:
            # One message string: passing several strings to RuntimeError renders as a tuple.
            raise RuntimeError(
                "Please set 'VOLC_ACCESS_KEY_ID' and 'VOLC_SECRET_ACCESS_KEY_ID' using environment variable!"
            ) from exc

        self.client = self.handler.TosClientV2(access_key, secret_key, endpoint, region)

    @staticmethod
    def _client_error(exc) -> RuntimeError:
        """Build the RuntimeError reported for a ``TosClientError``."""
        return RuntimeError(f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}")

    @staticmethod
    def _server_error(exc) -> RuntimeError:
        """Build the RuntimeError reported for a ``TosServerError``."""
        return RuntimeError(
            f"Volc Network Error: fail with server error, code: {exc.code}, "
            f"error with request id: {exc.request_id}, "
            f"error with message: {exc.message}, "
            f"error with http code: {exc.status_code}"
        )

    @staticmethod
    def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
        """Serialize ``saved_obj`` with ``torch.save`` in memory and upload it to ``fp``.

        Args:
            handler: the ``VolcClient`` instance to upload with.
            bucket_name (str): target bucket.
            fp (str): object key inside the bucket.
            saved_obj: object to serialize; must not be None.
            **kwargs: forwarded to ``torch.save``.

        Raises:
            RuntimeError: if the SDK reports a client or server error.
        """
        assert saved_obj is not None, "saved_obj is None!"
        try:
            with io.BytesIO() as f:
                torch.save(saved_obj, f, **kwargs)
                # Rewind so the upload reads the buffer from the beginning.
                f.seek(0)
                handler.client.put_object(bucket_name, fp, content=f)
        except handler.handler.exceptions.TosClientError as exc:
            raise VolcClient._client_error(exc) from exc
        except handler.handler.exceptions.TosServerError as exc:
            raise VolcClient._server_error(exc) from exc

    @staticmethod
    def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
        """Download the object at ``fp`` and deserialize it with ``torch.load``.

        Args:
            handler: the ``VolcClient`` instance to download with.
            bucket_name (str): source bucket.
            fp (str): Path to load, eg. vc://opennlplab/model_weights/xxx/ddd.pt
            **kwargs: forwarded to ``torch.load``.

        Returns:
            The object produced by ``torch.load``.

        Raises:
            RuntimeError: if the SDK reports a client or server error.
        """
        try:
            object_stream = handler.client.get_object(bucket_name, fp)
            buffer = io.BytesIO(object_stream.read())
            # NOTE(security): torch.load unpickles the payload; only load checkpoints from trusted buckets.
            states = torch.load(buffer, **kwargs)
        except handler.handler.exceptions.TosClientError as exc:
            raise VolcClient._client_error(exc) from exc
        except handler.handler.exceptions.TosServerError as exc:
            raise VolcClient._server_error(exc) from exc

        return states

    @staticmethod
    def assert_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
        """Assert that at least one object exists under the prefix ``fp``."""
        assert len(list(handler.client.list_objects_type2(bucket_name, prefix=fp).contents)) > 0, fp

    @staticmethod
    def is_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
        """Return True if at least one object exists under the prefix ``fp``."""
        result = handler.client.list_objects_type2(bucket_name, prefix=fp)
        if hasattr(result, "contents"):
            return len(list(result.contents)) > 0
        else:
            return False

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str):
        """List the distinct first-level names found under the prefix ``fp``.

        For every object key under ``fp`` the part after ``fp`` is taken and cut at
        its first ``/``, so both direct files and sub-folders are reported once.

        Returns:
            A list of unique names (in no particular order), or None when nothing
            exists under ``fp``.
        """
        if not VolcClient.is_fp_exists(handler, bucket_name, fp):
            if gpc.is_rank_for_log():
                logger.warning(f"'{fp}' not found!")
            return None

        names = set()
        result = handler.client.list_objects_type2(bucket_name, prefix=fp)
        while True:
            if hasattr(result, "contents"):
                for item in result.contents:
                    names.add(item.key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
            if not result.is_truncated:
                break
            # The listing is paginated; fetch the next page with the continuation token.
            result = handler.client.list_objects_type2(
                bucket_name, prefix=fp, continuation_token=result.next_continuation_token
            )

        return list(names)

    @staticmethod
    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
        """Upload the local file ``local_nvme_path`` to the object key ``fp``.

        Raises:
            RuntimeError: if the SDK reports a client or server error. Any other
                exception propagates unchanged.
        """
        try:
            handler.client.put_object_from_file(bucket_name, fp, local_nvme_path)
        except handler.handler.exceptions.TosClientError as exc:
            raise VolcClient._client_error(exc) from exc
        except handler.handler.exceptions.TosServerError as exc:
            raise VolcClient._server_error(exc) from exc

    @staticmethod
    def delete_obj(handler, fp: str):
        """Not supported by this backend; always raises NotImplementedError."""
        raise NotImplementedError("volc not support delete_obj")
|
||||||
|
|
||||||
|
|
||||||
class LocalClient(StorageClient):
|
class LocalClient(StorageClient):
|
||||||
"""
|
"""
|
||||||
Storage Client for local NFS.
|
Storage Client for local NFS.
|
||||||
|
@ -388,8 +576,35 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInfo:
    """Parse a ``vc://`` url into a :class:`VolcMetaInfo`.

    The first path component (the host) is split on ``.``: the first label is the
    bucket name, the second label is discarded, and the remaining labels form the
    endpoint. The region is the endpoint's first label with its leading
    ``-``-separated token removed. Everything after the host is the object path.

    Args:
        fp (str): url starting with ``vc://``.
        tmp_local_folder (str): folder for the temporary file used by async upload.
        is_async (bool): whether the meta is for an asynchronous upload.

    Returns:
        VolcMetaInfo: meta with ``handler`` left as None; ``local_nvme_path`` is
        only set when ``is_async`` is true.

    Raises:
        AssertionError: if ``fp`` is not a ``vc://`` url or its host cannot be parsed.
    """
    scheme = "vc://"
    assert fp.startswith(scheme), f"Path '{fp}' is not a volc url"
    # Slice the scheme off rather than using str.lstrip: lstrip treats its argument as a
    # set of characters and would also eat a bucket name's leading 'v', 'c', ':' or '/'.
    parts = fp[len(scheme) :].split(os.path.sep)
    match = volc_url_re.match(parts[0])
    assert match is not None, f"url '{fp}' is not a valid volc url"
    bucket_name, endpoint = match.group(1), match.group(2)
    temp_part = endpoint.split(".")
    # The endpoint/region extraction below needs at least two labels after the bucket.
    assert len(temp_part) > 1, f"url '{fp}' is not a valid volc url"
    endpoint = ".".join(temp_part[1:])
    region = temp_part[1].split("-")
    region = "-".join(region[1:])

    if is_async:
        tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
    else:
        tmp_step_file = None
    return VolcMetaInfo(
        is_async=is_async,
        handler=None,
        bucket_name=bucket_name,
        endpoint=endpoint,
        region=region,
        file_path=os.path.sep.join(parts[1:]),
        async_upload_fn=VolcClient.async_upload_fileobj,
        local_nvme_path=tmp_step_file,
    )
|
||||||
|
|
||||||
|
|
||||||
def get_local_meta(fp: str) -> LocalMetaInfo:
    """Wrap a local path in a ``LocalMetaInfo``, rejecting ``s3://`` and ``vc://`` urls."""
    looks_remote = fp.startswith(("s3://", "vc://"))
    assert not looks_remote, f"Path '{fp}' is not a local path"
    return LocalMetaInfo(fp)
|
||||||
|
|
||||||
|
|
||||||
|
@ -430,10 +645,11 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
TODO: add a thread to poll the asynchronous storage state.
|
TODO: add a thread to poll the asynchronous storage state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BACKEND_TYPE = {"boto3", "local"}
|
BACKEND_TYPE = {"boto3", "local", "volc"}
|
||||||
BACKEND_INIT_METHOD = {
|
BACKEND_INIT_METHOD = {
|
||||||
"boto3": Boto3Client,
|
"boto3": Boto3Client,
|
||||||
"local": LocalClient,
|
"local": LocalClient,
|
||||||
|
"volc": VolcClient,
|
||||||
}
|
}
|
||||||
CLI_DICT = {}
|
CLI_DICT = {}
|
||||||
|
|
||||||
|
@ -476,11 +692,12 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
|
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
|
||||||
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
|
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
|
||||||
|
|
||||||
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]:
|
||||||
"""
|
"""
|
||||||
example:
|
example:
|
||||||
local:/path/to/checkpoint
|
local:/path/to/checkpoint
|
||||||
boto3:s3://model_weights/0331/120bi
|
boto3:s3://model_weights/0331/120bi
|
||||||
|
volc:vc://model_weights/0331/120bi
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): _description_
|
path (str): _description_
|
||||||
|
@ -507,10 +724,29 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
the proxy may make boto3 unavailable or affect performance."
|
the proxy may make boto3 unavailable or affect performance."
|
||||||
)
|
)
|
||||||
self.has_warning = True
|
self.has_warning = True
|
||||||
|
elif backend == "volc":
|
||||||
|
meta_info = get_volc_meta(path, self.tmp_local_folder, async_mode)
|
||||||
|
backend_key = backend + ":" + meta_info.endpoint
|
||||||
|
init_args = (
|
||||||
|
meta_info.endpoint,
|
||||||
|
meta_info.region,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"http_proxy" in os.environ
|
||||||
|
or "https_proxy" in os.environ
|
||||||
|
or "HTTP_PROXY" in os.environ
|
||||||
|
or "HTTPS_PROXY" in os.environ
|
||||||
|
):
|
||||||
|
if not self.has_warning and gpc.is_rank_for_log():
|
||||||
|
logger.warning(
|
||||||
|
"HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
|
||||||
|
the proxy may make volc unavailable or affect performance."
|
||||||
|
)
|
||||||
|
self.has_warning = True
|
||||||
|
|
||||||
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
|
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
|
||||||
|
|
||||||
# boto3 backend need special treatment.
|
# boto3 and volc backend need special treatment.
|
||||||
if backend_key not in StorageManager.CLI_DICT:
|
if backend_key not in StorageManager.CLI_DICT:
|
||||||
StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})
|
StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})
|
||||||
|
|
||||||
|
@ -527,11 +763,10 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
return meta.client.get_fns(*unpack_nosave_meta(meta))
|
return meta.client.get_fns(*unpack_nosave_meta(meta))
|
||||||
|
|
||||||
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
|
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
|
||||||
|
|
||||||
if async_upload is None:
|
if async_upload is None:
|
||||||
async_upload = self.async_mode
|
async_upload = self.async_mode
|
||||||
|
|
||||||
if not save_path.startswith("boto3:"):
|
if not save_path.startswith("boto3:") and not save_path.startswith("volc:"):
|
||||||
async_upload = False
|
async_upload = False
|
||||||
|
|
||||||
meta = self._get_client(save_path, async_upload)
|
meta = self._get_client(save_path, async_upload)
|
||||||
|
@ -554,6 +789,7 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
def load(self, load_path: str, **kwargs) -> Any:
|
def load(self, load_path: str, **kwargs) -> Any:
|
||||||
self.wait()
|
self.wait()
|
||||||
meta = self._get_client(path=load_path)
|
meta = self._get_client(path=load_path)
|
||||||
|
|
||||||
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
|
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
|
||||||
|
|
||||||
def delete_obj(self, fp: str):
|
def delete_obj(self, fp: str):
|
||||||
|
|
|
@ -6,9 +6,9 @@ import torch
|
||||||
from internlm.core.context.parallel_context import Config
|
from internlm.core.context.parallel_context import Config
|
||||||
from internlm.initialize.launch import get_config_value
|
from internlm.initialize.launch import get_config_value
|
||||||
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
||||||
ASYNC_TMP_FOLDER,
|
|
||||||
BOTO_SAVE_PATH,
|
BOTO_SAVE_PATH,
|
||||||
LOCAL_SAVE_PATH,
|
LOCAL_SAVE_PATH,
|
||||||
|
VOLC_SAVE_PATH,
|
||||||
del_tmp_file,
|
del_tmp_file,
|
||||||
init_dist_and_model,
|
init_dist_and_model,
|
||||||
reset_singletons,
|
reset_singletons,
|
||||||
|
@ -48,6 +48,22 @@ ckpt_config_list = [
|
||||||
save_folder=LOCAL_SAVE_PATH,
|
save_folder=LOCAL_SAVE_PATH,
|
||||||
test_id=3,
|
test_id=3,
|
||||||
),
|
),
|
||||||
|
# async volc
|
||||||
|
dict(
|
||||||
|
enable_save_ckpt=True,
|
||||||
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
|
async_upload=True,
|
||||||
|
save_folder=VOLC_SAVE_PATH,
|
||||||
|
test_id=4,
|
||||||
|
),
|
||||||
|
# sync volc
|
||||||
|
dict(
|
||||||
|
enable_save_ckpt=True,
|
||||||
|
async_upload_tmp_folder=None,
|
||||||
|
async_upload=False,
|
||||||
|
save_folder=VOLC_SAVE_PATH,
|
||||||
|
test_id=5,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,6 +113,9 @@ internlm_ckpt_path = [
|
||||||
("/mnt/ckpt/", "local", "/mnt/ckpt/"),
|
("/mnt/ckpt/", "local", "/mnt/ckpt/"),
|
||||||
("./ckpt/", "local", "./ckpt/"),
|
("./ckpt/", "local", "./ckpt/"),
|
||||||
("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
|
("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
|
||||||
|
("volc:vc://oss_bucket/", "volc", "vc://oss_bucket/"),
|
||||||
|
("volc:oss_bucket/", "volc", "oss_bucket/"),
|
||||||
|
("vc://oss_bucket/", "volc", "vc://oss_bucket/"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue