feat(storage): support ali oss ckpt saving (#439)

pull/464/head
jiaxingli 2023-10-27 22:32:10 +08:00 committed by GitHub
parent e6d8ebc3e5
commit 4995060d84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 250 additions and 14 deletions

View File

@ -39,7 +39,7 @@ CheckpointManager
load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"),
auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 and volc ckpt)
async_upload=True, # async ckpt upload. (only work for boto3, volc and oss2 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
@ -69,7 +69,9 @@ InternLM对config中出现的所有存储路径都遵循以下的路径格式约
2. 如果需要使用volc的路径需要在运行前提前导入 ``VOLC_ACCESS_KEY_ID````VOLC_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
3. bucket的endpoint一般分为Inside IP和Outside IP如果可以尽量使用inside IP会获得更佳的存储速度。
3. 如果需要使用oss2的路径需要在运行前提前导入 ``ALI_ACCESS_KEY_ID````ALI_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
4. bucket的endpoint一般分为Inside IP和Outside IP如果可以尽量使用inside IP会获得更佳的存储速度。
@ -116,7 +118,7 @@ config.ckpt 中相关的参数:
- ``async_upload_tmp_folder``: 异步上传临时存储路径。参数类型 ``str/None``, 默认值为 ``/dev/shm/{JOB_NAME}_tmp_ckpt/``
需要注意的是异步上传功能仅在backend为boto3或volc时才会有效果bcakend为local时只支持同步存储。
需要注意的是,异步上传功能仅在backend为非local时才会有效果,backend为local时只支持同步存储。
``async_upload_tmp_folder`` 设置的原则为尽量设置为计算节点的local目录,这样才可以获得最佳的异步上传速度,一般来说建议为 ``/dev/shm`` 或 ``/nvme`` 下的路径,如果使用同步上传,则该路径可不给。

Binary file not shown.

Before

Width:  |  Height:  |  Size: 212 KiB

After

Width:  |  Height:  |  Size: 282 KiB

View File

@ -25,15 +25,25 @@ from internlm.utils.logger import get_logger
try:
import boto3
import botocore
except ImportError:
pass
try:
import tos
except ImportError:
pass
try:
import oss2
except ImportError:
pass
logger = get_logger(__file__)
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
ali_url_re = re.compile(r"([^/.]+)\.([^/.]+\..+)")
MB = 1024**2
@ -165,6 +175,45 @@ region:{self.region}, local_nvme_path: {self.local_nvme_path}"
return meta.client, meta.bucket_name, meta.file_path
class AliMetaInfo:
    """Metadata bundle describing one Ali OSS save/load target.

    Groups everything the storage manager needs to perform a save or load
    against Ali OSS: the client handle, bucket/endpoint coordinates, the
    object path, and (for saves only) the async-upload staging information.
    """

    def __init__(
        self,
        is_async,
        handler: StorageClient,
        bucket_name: str,
        endpoint: str,
        file_path: str,
        async_upload_fn: callable,
        local_nvme_path=None,
    ) -> None:
        # Fields consulted by every operation (load and save).
        self.client = handler
        self.bucket_name = bucket_name
        self.file_path = file_path
        # Fields consulted only on the save path.
        self.local_nvme_path = local_nvme_path
        self.is_async = is_async
        self.endpoint = endpoint
        self.async_upload_fn = async_upload_fn

    def __str__(self) -> str:
        return (
            f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, "
            f"endpoint:{self.endpoint}, local_nvme_path: {self.local_nvme_path}"
        )

    @staticmethod
    def unpack_ali_save_meta(meta):
        # Async saves additionally need the local staging file path.
        base = (meta.client, meta.file_path)
        return base + (meta.local_nvme_path,) if meta.is_async else base

    @staticmethod
    def unpack_ali_nosave_meta(meta):
        return (meta.client, meta.file_path)
class LocalMetaInfo:
"""Local meta info for save/load etc."""
@ -182,22 +231,26 @@ class LocalMetaInfo:
return (meta.file_path,)
def unpack_save_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, AliMetaInfo, LocalMetaInfo]):
    """Dispatch *meta* to its backend-specific save-argument unpacker.

    Args:
        meta: a backend meta-info object (boto3 / volc / ali / local).

    Returns:
        tuple: the positional arguments the backend's save routine expects.

    Raises:
        ValueError: if ``meta`` is not a recognized meta-info type.
    """
    if isinstance(meta, Boto3MetaInfo):
        return Boto3MetaInfo.unpack_boto3_save_meta(meta)
    elif isinstance(meta, VolcMetaInfo):
        return VolcMetaInfo.unpack_volc_save_meta(meta)
    elif isinstance(meta, AliMetaInfo):
        return AliMetaInfo.unpack_ali_save_meta(meta)
    elif isinstance(meta, LocalMetaInfo):
        return LocalMetaInfo.unpack_local_save_meta(meta)
    else:
        raise ValueError(f"unknown meta info: {type(meta)}")
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, AliMetaInfo, LocalMetaInfo]):
    """Dispatch *meta* to its backend-specific load/read-argument unpacker.

    Args:
        meta: a backend meta-info object (boto3 / volc / ali / local).

    Returns:
        tuple: the positional arguments the backend's non-save routines expect.

    Raises:
        ValueError: if ``meta`` is not a recognized meta-info type.
    """
    if isinstance(meta, Boto3MetaInfo):
        return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
    elif isinstance(meta, VolcMetaInfo):
        return VolcMetaInfo.unpack_volc_nosave_meta(meta)
    elif isinstance(meta, AliMetaInfo):
        return AliMetaInfo.unpack_ali_nosave_meta(meta)
    elif isinstance(meta, LocalMetaInfo):
        return LocalMetaInfo.unpack_local_nosave_meta(meta)
    else:
        raise ValueError(f"unknown meta info: {type(meta)}")
@ -221,6 +274,10 @@ def try_get_storage_backend(path: str):
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
return "volc", path
elif path.startswith("ali:"):
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.")
return "oss2", path
else:
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
@ -500,6 +557,99 @@ class VolcClient(StorageClient):
raise NotImplementedError("volc not support delete_obj")
class AliClient(StorageClient):
    """Ali OSS object-storage client backed by the ``oss2`` SDK."""

    def __init__(
        self,
        bucket_name: str,
        endpoint: str,
    ) -> None:
        """Ali object/file storage management class.

        Credentials are read from the ``ALI_ACCESS_KEY_ID`` and
        ``ALI_SECRET_ACCESS_KEY_ID`` environment variables.

        Args:
            bucket_name (str): Ali OSS bucket name.
            endpoint (str): Ali OSS endpoint.

        Raises:
            RuntimeError: if the credential environment variables are unset.
        """
        super().__init__(oss2)
        try:
            access_key = os.environ["ALI_ACCESS_KEY_ID"]
            secret_key = os.environ["ALI_SECRET_ACCESS_KEY_ID"]
        except KeyError as exc:
            # Single concatenated message: passing two strings would make the
            # exception carry a tuple instead of one readable message.
            raise RuntimeError(
                "Please set 'ALI_ACCESS_KEY_ID' and 'ALI_SECRET_ACCESS_KEY_ID' "
                "using environment variable!"
            ) from exc

        self.auth = self.handler.Auth(access_key, secret_key)
        self.client = self.handler.Bucket(self.auth, endpoint, bucket_name)

    @staticmethod
    def sync_upload_fileobj(handler, fp: str, saved_obj=None, **kwargs):
        """Serialize ``saved_obj`` with ``torch.save`` and upload it to object key ``fp``."""
        assert saved_obj is not None, "saved_obj is None!"
        with io.BytesIO() as f:
            torch.save(saved_obj, f, **kwargs)
            f.seek(0)
            handler.client.put_object(fp, f)

    @staticmethod
    def load(handler, fp: str, **kwargs) -> Dict:
        """Download object ``fp`` and deserialize it with ``torch.load``.

        Args:
            fp (str): Path to load, eg. ali://opennlplab/model_weights/xxx/ddd.pt
        """
        object_stream = handler.client.get_object(fp)
        buffer = io.BytesIO(object_stream.read())
        return torch.load(buffer, **kwargs)

    @staticmethod
    def assert_fp_exists(handler, fp: str):  # pylint: disable=W0613
        """Assert that at least one object exists under prefix ``fp``."""
        assert len(list(handler.handler.ObjectIteratorV2(handler.client, prefix=fp))) > 0, fp

    @staticmethod
    def is_fp_exists(handler, fp: str):  # pylint: disable=W0613
        """Return True if at least one object exists under prefix ``fp``."""
        return len(list(handler.handler.ObjectIteratorV2(handler.client, prefix=fp))) > 0

    @staticmethod
    def get_fns(handler, fp: str):
        """Return the de-duplicated basenames of objects under prefix ``fp``.

        Returns None (with a warning on the logging rank) when nothing is found.
        """
        if not AliClient.is_fp_exists(handler, fp):
            if gpc.is_rank_for_log():
                logger.warning(f"'{fp}' not found!")
            return None
        folder_name_list = [
            obj.key.split("/")[-1]
            for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp)
        ]
        return list(set(folder_name_list))

    @staticmethod
    def async_upload_fileobj(handler, fp: str, local_nvme_path: str):
        """Upload the staged local file at ``local_nvme_path`` to object key ``fp``."""
        handler.client.put_object_from_file(fp, local_nvme_path)

    @staticmethod
    def delete_obj(handler, fp: str):
        raise NotImplementedError("ali not support delete_obj")
class LocalClient(StorageClient):
"""
Storage Client for local NFS.
@ -582,9 +732,8 @@ def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInf
match = volc_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid volc url"
bucket_name, endpoint = match.group(1), match.group(2)
temp_part = endpoint.split(".")
endpoint = ".".join(temp_part[1:])
region = temp_part[1].split("-")
region = endpoint.split(".")
region = region[0].split("-")
region = "-".join(region[1:])
if is_async:
@ -603,8 +752,32 @@ def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInf
)
def get_ali_meta(fp: str, tmp_local_folder: str, is_async: bool) -> AliMetaInfo:
    """Parse an ``ali://`` url into an :class:`AliMetaInfo`.

    Args:
        fp (str): url of the form ali://<bucket>.<endpoint>/<path>.
        tmp_local_folder (str): staging folder used for asynchronous uploads.
        is_async (bool): whether the checkpoint will be uploaded asynchronously.
    """
    assert fp.startswith("ali://"), f"Path '{fp}' is not a ali url"
    # NOTE: slice the prefix off explicitly. str.lstrip("ali://") strips a
    # character SET, so it would also eat leading 'a'/'l'/'i'/':'/'/' characters
    # of the bucket name (e.g. "ali://internlm.endpoint/..." -> "nternlm...").
    parts = fp[len("ali://"):].split(os.path.sep)
    match = ali_url_re.match(parts[0])
    assert match is not None, f"url '{fp}' is not a valid ali url"
    bucket_name, endpoint = match.group(1), match.group(2)

    if is_async:
        tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
    else:
        tmp_step_file = None

    return AliMetaInfo(
        is_async=is_async,
        handler=None,
        bucket_name=bucket_name,
        endpoint=endpoint,
        file_path=os.path.sep.join(parts[1:]),
        async_upload_fn=AliClient.async_upload_fileobj,
        local_nvme_path=tmp_step_file,
    )
def get_local_meta(fp: str) -> LocalMetaInfo:
    """Parse a local filesystem path into a :class:`LocalMetaInfo`.

    Raises:
        AssertionError: if ``fp`` carries a remote (s3/vc/ali) prefix.
    """
    assert (
        not fp.startswith("s3://") and not fp.startswith("vc://") and not fp.startswith("ali://")
    ), f"Path '{fp}' is not a local path"
    return LocalMetaInfo(fp)
@ -645,11 +818,12 @@ class StorageManager(metaclass=SingletonMeta):
TODO: add a thread to poll the asynchronous storage state.
"""
# Supported storage backend identifiers and their client constructors.
BACKEND_TYPE = {"boto3", "local", "volc", "oss2"}
BACKEND_INIT_METHOD = {
    "boto3": Boto3Client,
    "local": LocalClient,
    "volc": VolcClient,
    "oss2": AliClient,
}
# Cache of instantiated backend clients, keyed by "<backend>:<endpoint>".
CLI_DICT = {}
@ -692,12 +866,15 @@ class StorageManager(metaclass=SingletonMeta):
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]:
def _get_client(
self, path: str, async_mode: bool = False
) -> Union[Boto3MetaInfo, VolcMetaInfo, AliMetaInfo, LocalMetaInfo]:
"""
example:
local:/path/to/checkpoint
boto3:s3://model_weights/0331/120bi
volc:vc://model_weights/0331/120bi
oss2:ali://model_weights/0331/120bi
Args:
path (str): _description_
@ -743,10 +920,29 @@ class StorageManager(metaclass=SingletonMeta):
the proxy may make volc unavailable or affect performance."
)
self.has_warning = True
elif backend == "oss2":
meta_info = get_ali_meta(path, self.tmp_local_folder, async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (
meta_info.bucket_name,
meta_info.endpoint,
)
if (
"http_proxy" in os.environ
or "https_proxy" in os.environ
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning and gpc.is_rank_for_log():
logger.warning(
"HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \
the proxy may make oss2 unavailable or affect performance."
)
self.has_warning = True
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
# boto3 and volc backend need special treatment.
# boto3, volc and oss2 backend need special treatment.
if backend_key not in StorageManager.CLI_DICT:
StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})
@ -766,7 +962,11 @@ class StorageManager(metaclass=SingletonMeta):
if async_upload is None:
async_upload = self.async_mode
if not save_path.startswith("boto3:") and not save_path.startswith("volc:"):
if (
not save_path.startswith("boto3:")
and not save_path.startswith("volc:")
and not save_path.startswith("oss2:")
):
async_upload = False
meta = self._get_client(save_path, async_upload)

View File

@ -22,6 +22,9 @@ BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
VOLC_SAVE_PATH = f"volc:vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
VOLC_SAVE_PATH_NO_PRFIX = f"vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
ALI_SAVE_PATH = f"oss2:ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
ALI_SAVE_PATH_NO_PRFIX = f"ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
ASYNC_TMP_FOLDER = "./async_tmp_folder"
@ -197,3 +200,14 @@ def del_tmp_file():
print(presults, flush=True)
except: # noqa # pylint: disable=bare-except
pass
try:
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + ALI_SAVE_PATH_NO_PRFIX + " / "
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
results, presults = "", ""
for line in iter(output.stdout.readline, b""):
results += str(line.rstrip())
presults += line.rstrip().decode() + "\n"
print(presults, flush=True)
except: # noqa # pylint: disable=bare-except
pass

View File

@ -6,6 +6,7 @@ import torch
from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ALI_SAVE_PATH,
BOTO_SAVE_PATH,
LOCAL_SAVE_PATH,
VOLC_SAVE_PATH,
@ -63,6 +64,22 @@ ckpt_config_list = [
save_folder=VOLC_SAVE_PATH,
test_id=6,
),
# async ali
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
async_upload=True,
save_folder=ALI_SAVE_PATH,
test_id=7,
),
# sync ali
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=None,
async_upload=False,
save_folder=ALI_SAVE_PATH,
test_id=8,
),
]
@ -89,7 +106,7 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
ckpt_config = Config(ckpt_config)
if os.environ.get("OSS_BUCKET_NAME") is None:
if ckpt_config.test_id > 2:
print("Pass boto3 and volc", flush=True)
print("Pass boto3, volc and ali", flush=True)
return
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
@ -120,6 +137,9 @@ internlm_ckpt_path = [
("volc:vc://oss_bucket/", "volc", "vc://oss_bucket/"),
("volc:oss_bucket/", "volc", "oss_bucket/"),
("vc://oss_bucket/", "volc", "vc://oss_bucket/"),
("oss2:ali://oss_bucket/", "oss2", "ali://oss_bucket/"),
("oss2:oss_bucket/", "oss2", "oss_bucket/"),
("ali://oss_bucket/", "oss2", "ali://oss_bucket/"),
]