mirror of https://github.com/InternLM/InternLM
fix(storage): fix and refactor storage api (#281)
parent 8d8d811e10
commit 8acf823a04
internlm/utils/storage_manager.py (file path inferred from the imports in the tests further down):

@@ -46,12 +46,12 @@ def get_fns(fp: str):
     return storage_manager.get_fns(fp)
 
 
-def llm_load(fp: str, *args, **kwargs):
-    return storage_manager.load(fp, *args, **kwargs)
+def llm_load(fp: str, **kwargs):
+    return storage_manager.load(fp, **kwargs)
 
 
-def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
-    storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
+def llm_save(save_path: str, saved_obj: Any, **kwargs):
+    storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
 
 
 class StorageClient:
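Note: the public helpers above no longer thread positional `*args` through to `torch.save`/`torch.load`; extra options must now travel as keywords. A minimal usage sketch, mirroring the new test at the bottom of this commit (it assumes `init_storage_manager(...)` has already been called, as the test does):

    import torch
    from internlm.utils.storage_manager import llm_load, llm_save

    tensor = torch.rand(4, 4)
    llm_save("local:local_ckpt/test.pt", tensor)  # "local:" prefix selects the local backend
    loaded = llm_load("local:local_ckpt/test.pt", map_location="cpu")  # kwargs reach torch.load
    assert (loaded == tensor).all()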
@@ -63,19 +63,23 @@ class StorageClient:
         self.handler = handler
 
     @staticmethod
-    def load(client, load_path: str, *args, **kwargs):
+    def load(*args, **kwargs):
         raise NotImplementedError
 
     @staticmethod
-    def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
+    def sync_upload_fileobj(*args, **kwargs):
         raise NotImplementedError
 
     @staticmethod
-    def assert_fp_exists(client):
+    def async_upload_fileobj(*args, **kwargs):
         raise NotImplementedError
 
     @staticmethod
-    def get_fns(client):
+    def assert_fp_exists(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def get_fns(*args, **kwargs):
         raise NotImplementedError
 
 
@@ -92,40 +96,65 @@ class Boto3MetaInfo:
         async_upload_fn: callable,
         local_nvme_path=None,
     ) -> None:
-        self.is_async = is_async
+        # info needed by all operations.
         self.client = handler
         self.bucket_name = bucket_name
-        self.endpoint = endpoint
         self.file_path = file_path
-        self.async_upload_fn = async_upload_fn
+        # info needed only by save.
         self.local_nvme_path = local_nvme_path
+        self.is_async = is_async
+        self.endpoint = endpoint
+        self.async_upload_fn = async_upload_fn
 
     def __str__(self) -> str:
         return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
 local_nvme_path: {self.local_nvme_path}"
 
+    @staticmethod
+    def unpack_boto3_save_meta(meta):
+        if meta.is_async:
+            return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
+        else:
+            return meta.client, meta.bucket_name, meta.file_path
+
+    @staticmethod
+    def unpack_boto3_nosave_meta(meta):
+        return meta.client, meta.bucket_name, meta.file_path
+
 
 class LocalMetaInfo:
     """Local meta info for save/load etc."""
 
-    def __init__(self, handler: StorageClient, dest_path: str) -> None:
-        self.is_async = False
-        self.client = handler
-        self.dest_path = dest_path
+    def __init__(self, file_path: str) -> None:
+        self.file_path = file_path
         self.async_upload_fn = None
+        self.is_async = False
+
+    @staticmethod
+    def unpack_local_save_meta(meta):
+        return (meta.file_path,)
+
+    @staticmethod
+    def unpack_local_nosave_meta(meta):
+        return (meta.file_path,)
 
 
-def unpack_meta(meta):
-    args = []
-    is_async = meta.is_async
-    for k, v in meta.__dict__.items():
-        if k in ("endpoint", "async_upload_fn", "is_async"):
-            continue
-        if not is_async and k in ("local_nvme_path",):
-            continue
-        args.append(v)
-    return args
+def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
+    if isinstance(meta, Boto3MetaInfo):
+        return Boto3MetaInfo.unpack_boto3_save_meta(meta)
+    elif isinstance(meta, LocalMetaInfo):
+        return LocalMetaInfo.unpack_local_save_meta(meta)
+    else:
+        raise ValueError(f"unknown meta info: {type(meta)}")
+
+
+def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
+    if isinstance(meta, Boto3MetaInfo):
+        return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
+    elif isinstance(meta, LocalMetaInfo):
+        return LocalMetaInfo.unpack_local_nosave_meta(meta)
+    else:
+        raise ValueError(f"unknown meta info: {type(meta)}")
 
 
 def compute_file_md5_by_chunk(file_name: str):
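Note: the old `unpack_meta` discovered its arguments by iterating `meta.__dict__` and skipping fields by name, which silently depended on attribute insertion order (hence the reordering of the `__init__` assignments above). The replacement dispatches explicitly. A small sketch of what each helper yields; the constructor keywords are taken from the assignments above and the values are illustrative:

    boto_meta = Boto3MetaInfo(is_async=True, handler=None, bucket_name="bkt", endpoint="ep",
                              file_path="ckpt/model.pt", async_upload_fn=None,
                              local_nvme_path="/nvme/model.pt")
    unpack_save_meta(boto_meta)    # -> (meta.client, "bkt", "ckpt/model.pt", "/nvme/model.pt")
    unpack_nosave_meta(boto_meta)  # -> (meta.client, "bkt", "ckpt/model.pt"); loads never need the tmp file

    local_meta = LocalMetaInfo("ckpt/model.pt")
    unpack_save_meta(local_meta)   # -> ("ckpt/model.pt",)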
@@ -205,13 +234,11 @@ class Boto3Client(StorageClient):
         )
 
     @staticmethod
-    def sync_upload_fileobj(
-        handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
-    ):  # pylint: disable=W0613
+    def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
         assert saved_obj is not None, "saved_obj is None!"
         try:
             with io.BytesIO() as f:
-                torch.save(saved_obj, f, *args, **kwargs)
+                torch.save(saved_obj, f, **kwargs)
                 f.seek(0)
                 handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
         except handler.botocore.exceptions.EndpointConnectionError as exc:
@@ -220,14 +247,7 @@ class Boto3Client(StorageClient):
         ) from exc
 
     @staticmethod
-    def load(
-        handler,
-        bucket_name: str,
-        fp: str,
-        local_nvme_path: str,  # pylint: disable=W0613
-        *args,
-        **kwargs,
-    ) -> Dict:
+    def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
         """
         Args:
             fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@@ -236,7 +256,7 @@ class Boto3Client(StorageClient):
             with io.BytesIO() as f:
                 handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                 f.seek(0)
-                states = torch.load(f, *args, **kwargs)
+                states = torch.load(f, **kwargs)
         except handler.botocore.exceptions.EndpointConnectionError as exc:
             raise RuntimeError(
                 f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@@ -244,11 +264,11 @@ class Boto3Client(StorageClient):
         return states
 
     @staticmethod
-    def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
+    def assert_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
         assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
 
     @staticmethod
-    def is_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
+    def is_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
         re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
         if "Contents" in re:
             return len(list(re["Contents"])) > 0
@@ -256,12 +276,12 @@ class Boto3Client(StorageClient):
         return False
 
     @staticmethod
-    def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs):  # pylint: disable=W0613
+    def get_fns(handler, bucket_name: str, fp: str):
         """
         Ref: https://stackoverflow.com/questions/54314563/
         how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
         """
-        if Boto3Client.is_fp_exists(handler, bucket_name, fp, None):
+        if Boto3Client.is_fp_exists(handler, bucket_name, fp):
             paginator = handler.client.get_paginator("list_objects_v2")
             pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
             folder_name_list = []
@@ -302,30 +322,26 @@ class LocalClient(StorageClient):
         super().__init__(None)
 
     @staticmethod
-    def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
-        assert isinstance(handler, LocalClient)
+    def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
         assert saved_obj is not None
         fp_dirname = os.path.dirname(fp)
         if not os.path.exists(fp_dirname):
             os.makedirs(fp_dirname, exist_ok=True)
-        torch.save(saved_obj, fp, *args, **kwargs)
+        torch.save(saved_obj, fp, **kwargs)
 
     @staticmethod
-    def load(handler, fp: str, *args, **kwargs):  # pylint: disable=W0613
-        assert isinstance(handler, LocalClient)
-        assert os.path.exists(fp), f"{fp} is not found!"
-        with open(fp, "rb") as f:
-            states = torch.load(f, *args, **kwargs)
+    def load(load_path: str, **kwargs):
+        assert os.path.exists(load_path), f"{load_path} is not found!"
+        with open(load_path, "rb") as f:
+            states = torch.load(f, **kwargs)
         return states
 
     @staticmethod
-    def assert_fp_exists(handler, folder):
-        assert isinstance(handler, LocalClient)
+    def assert_fp_exists(folder):
         assert os.path.exists(folder), folder
 
     @staticmethod
-    def get_fns(handler, folder):
-        assert isinstance(handler, LocalClient)
+    def get_fns(folder):
         if not os.path.exists(folder):
             if gpc.is_rank_for_log():
                 logger.warning(f"'{folder}' not found!")
@@ -334,8 +350,7 @@ class LocalClient(StorageClient):
         return os.listdir(folder)
 
     @staticmethod
-    def delete_obj(handler, fp: str):
-        assert isinstance(handler, LocalClient)
+    def delete_obj(fp: str):
         if not os.path.isdir(fp):
             os.remove(fp)
@@ -359,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
     assert match is not None, f"url '{fp}' is not a valid boto3 url"
     bucket_name, endpoint = match.group(1), match.group(2)
     endpoint = "http://" + endpoint + ":80"
-    tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
+    if is_async:
+        tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
+    else:
+        tmp_step_file = None
     return Boto3MetaInfo(
         is_async=is_async,
         handler=None,
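Note: the temporary staging file is now allocated only when it will actually be used; for synchronous saves `local_nvme_path` stays `None` and is never unpacked into the upload call (see `unpack_boto3_save_meta` above). A quick check, with a hypothetical s3 URL of the shape the test fixtures below construct:

    sync_meta = get_boto3_meta("s3://my-bucket.10.0.0.1/user/CI_TEST/model.pt",
                               tmp_local_folder="/tmp/stage", is_async=False)
    assert sync_meta.local_nvme_path is None  # no tmp file for synchronous saves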
@@ -373,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
 
 def get_local_meta(fp: str) -> LocalMetaInfo:
     assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
-    return LocalMetaInfo(None, fp)
+    return LocalMetaInfo(fp)
 
 
 def get_mount_point_free_size(path: str):
@@ -459,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta):
             logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
             raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
 
-    def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
+    def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
         """
         example:
         local:/path/to/checkpoint
@@ -475,7 +493,7 @@ class StorageManager(metaclass=SingletonMeta):
             meta_info = get_local_meta(path)
             backend_key = backend
         elif backend == "boto3":
-            meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
+            meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
             backend_key = backend + ":" + meta_info.endpoint
             init_args = (meta_info.endpoint,)
             if (
@@ -503,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta):
 
     def assert_fp_exists(self, folder) -> None:
         meta = self._get_client(path=folder)
-        meta.client.assert_fp_exists(*unpack_meta(meta))
+        meta.client.assert_fp_exists(*unpack_nosave_meta(meta))
 
     def get_fns(self, folder) -> List[str]:
         meta = self._get_client(path=folder)
-        return meta.client.get_fns(*unpack_meta(meta))
+        return meta.client.get_fns(*unpack_nosave_meta(meta))
 
-    def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
-        meta = self._get_client(path=save_path)
-
+    def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
         if async_upload is None:
             async_upload = self.async_mode
+
+        if not save_path.startswith("boto3:"):
+            async_upload = False
+
+        meta = self._get_client(save_path, async_upload)
+
         if async_upload:
             assert (
                 self.tmp_local_folder
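Note: this is the core behavioral fix. Previously `_get_client` always consulted `self.async_mode`, so a per-call `async_upload=False` could still receive a meta built for async upload; now the resolved mode is passed down, and async is forced off for any destination without the `boto3:` prefix. The "async local" test config added below therefore degrades to a synchronous local save. A sketch of the resulting call matrix, assuming an initialized manager:

    storage_manager.save("boto3:s3://bkt.ep/x.pt", to_save_obj=obj)                     # async iff manager.async_mode
    storage_manager.save("boto3:s3://bkt.ep/x.pt", to_save_obj=obj, async_upload=False) # forced synchronous upload
    storage_manager.save("local:ckpt/x.pt", to_save_obj=obj, async_upload=True)         # downgraded to sync: not boto3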
@@ -521,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta):
             tmp_step_file = meta.local_nvme_path
             self._to_be_del_files.append(tmp_step_file)
             with open(tmp_step_file, "wb") as f:
-                torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
-            self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
+                torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+            self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
             os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
             self.async_task_peeding = True
         else:
-            meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
+            meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
             self.upload_count += 1
 
-    def load(self, load_path: str, *args, **kwargs) -> Any:
+    def load(self, load_path: str, **kwargs) -> Any:
         self.wait()
         meta = self._get_client(path=load_path)
-        return meta.client.load(*unpack_meta(meta), *args, **kwargs)
+        return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
 
     def delete_obj(self, fp: str):
         meta = self._get_client(path=fp)
-        meta.client.delete_obj(*unpack_meta(meta))
+        meta.client.delete_obj(*unpack_nosave_meta(meta))
 
     def _del_tmp_folder(self):
         for fp in self._to_be_del_files:
@@ -626,7 +649,7 @@ class StorageManager(metaclass=SingletonMeta):
         if self.async_mode and self.latest_save_folder:
             self.save(
                 os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
-                saved_obj=dict({"step": self.latest_save_step}),
+                to_save_obj=dict({"step": self.latest_save_step}),
                 async_upload=False,
             )
             self.latest_save_folder = None
tests/test_utils/common_fixture.py (path per the import statements in the test files below):

@@ -1,4 +1,6 @@
 import os
+import shutil
+from subprocess import PIPE, STDOUT, Popen
 
 import pytest
 import torch
@@ -8,6 +10,18 @@ from internlm.core.context.parallel_context import Config
 from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
 from internlm.utils.common import SingletonMeta
 
+OSS_NAME = os.environ["OSS_BUCKET_NAME"]
+OSS_IP = os.environ["OSS_IP"]
+USER = os.environ["USER"]
+JOB_NAME = "CI_TEST"
+LOCAL_SAVE_PATH = "local:local_ckpt"
+
+BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
+BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
+
+ASYNC_TMP_FOLDER = "./async_tmp_folder"
+
+
 # 1B
 init_config = Config(
     dict(
@@ -108,8 +122,10 @@ def reset_singletons():
 
 def reset_seed():
     from internlm.core.context.random import _SEED_MANAGER
+
     _SEED_MANAGER.reset()
 
 
 @pytest.fixture(scope="module")
 def init_dist_and_model():
     from internlm.initialize import initialize_distributed_env
@@ -136,8 +152,30 @@ def init_dist_and_model():
     reset_seed()
 
 
 def enter_flag(text):
     print(f"{text} begin!", flush=True)
     yield
     print(f"{text} end!", flush=True)
+
+
+def del_tmp_file():
+    try:
+        shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
+    except FileNotFoundError:
+        pass
+
+    try:
+        shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
+    except FileNotFoundError:
+        pass
+
+    try:
+        cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
+        with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
+            results, presults = "", ""
+            for line in iter(output.stdout.readline, b""):
+                results += str(line.rstrip())
+                presults += line.rstrip().decode() + "\n"
+        print(presults, flush=True)
+    except FileNotFoundError:
+        pass
checkpoint-manager test module (file path not shown in this view):

@@ -1,6 +1,4 @@
 import os
-import shutil
-from subprocess import PIPE, STDOUT, Popen
 
 import pytest
 import torch
@@ -13,6 +11,10 @@ from internlm.utils.common import SingletonMeta
 from internlm.utils.model_checkpoint import CheckpointManager
 from internlm.utils.storage_manager import wait_async_upload_finish
 from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
+    ASYNC_TMP_FOLDER,
+    BOTO_SAVE_PATH,
+    LOCAL_SAVE_PATH,
+    del_tmp_file,
     init_dist_and_model,
     reset_singletons,
 )
@@ -21,39 +23,6 @@ TOTAL_STEP = 6
 
 CKPT_EVERY = 4
 SNPASHOT_EVERY = 2
-OSS_NAME = os.environ["OSS_BUCKET_NAME"]
-OSS_IP = os.environ["OSS_IP"]
-USER = os.environ["USER"]
-JOB_NAME = "CI_TEST"
-LOCAL_SAVE_PATH = "local:local_ckpt"
-
-BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
-BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
-
-ASYNC_TMP_FOLDER = "./async_tmp_folder"
-
-
-def del_tmp_file():
-    try:
-        shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
-    except FileNotFoundError:
-        pass
-
-    try:
-        shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
-    except FileNotFoundError:
-        pass
-
-    try:
-        cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
-        with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
-            results, presults = "", ""
-            for line in iter(output.stdout.readline, b""):
-                results += str(line.rstrip())
-                presults += line.rstrip().decode() + "\n"
-        print(presults, flush=True)
-    except FileNotFoundError:
-        pass
-
-
 ckpt_config_list = [
storage-manager test module (file path not shown in this view):

@@ -1,21 +1,75 @@
+import os
+
 import pytest
+import torch
 
 from internlm.core.context.parallel_context import Config
 from internlm.initialize.launch import get_config_value
 from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
+    ASYNC_TMP_FOLDER,
     BOTO_SAVE_PATH,
-    TOTAL_STEP,
-    ckpt_config_list,
+    LOCAL_SAVE_PATH,
     del_tmp_file,
     init_dist_and_model,
     reset_singletons,
 )
 
+ASYNC_TMP_FOLDER = "./async_tmp_folder"
+ckpt_config_list = [
+    # async boto
+    dict(
+        enable_save_ckpt=True,
+        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
+        async_upload=True,
+        save_folder=BOTO_SAVE_PATH,
+        test_id=0,
+    ),
+    # sync local
+    dict(
+        enable_save_ckpt=True,
+        async_upload_tmp_folder=None,
+        async_upload=False,
+        save_folder=LOCAL_SAVE_PATH,
+        test_id=1,
+    ),
+    # sync boto
+    dict(
+        enable_save_ckpt=True,
+        async_upload_tmp_folder=None,
+        async_upload=False,
+        save_folder=BOTO_SAVE_PATH,
+        test_id=2,
+    ),
+    # async local
+    dict(
+        enable_save_ckpt=True,
+        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
+        async_upload=True,
+        save_folder=LOCAL_SAVE_PATH,
+        test_id=3,
+    ),
+]
+
+
+@pytest.fixture(scope="function")
+def del_tmp():
+    del_tmp_file()
+    yield
+    del_tmp_file()
+
+
+@pytest.mark.usefixtures("del_tmp")
 @pytest.mark.usefixtures("reset_singletons")
 @pytest.mark.parametrize("ckpt_config", ckpt_config_list)
-def test_storage_mm(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-argument
-    from internlm.utils.storage_manager import get_storage_manager, init_storage_manager
+def test_storage_mm_save_load(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-argument
+    from internlm.utils.storage_manager import (
+        check_folder,
+        get_fns,
+        init_storage_manager,
+        llm_load,
+        llm_save,
+        wait_async_upload_finish,
+    )
 
     ckpt_config = Config(ckpt_config)
     enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
@@ -23,4 +77,13 @@ def test_storage_mm(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-argument
     async_upload = get_config_value(ckpt_config, "async_upload", False)
 
     init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
-    get_storage_manager()
+
+    tobj = torch.rand(64, 64)
+    save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
+    llm_save(save_fn, tobj)
+    if ckpt_config.test_id == 0:
+        wait_async_upload_finish()
+    check_folder(save_fn)
+    assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
+    load_obj = llm_load(save_fn, map_location="cpu")
+    assert 0 == ((load_obj != tobj).sum())
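Note: the rewritten test now exercises all four save configurations end to end (save, wait for async upload where applicable, list, and load-back with equality check). A hedged invocation hint; the test file path is not shown in this view, so adjust it to wherever this module actually lives:

    # requires OSS_BUCKET_NAME, OSS_IP and USER in the environment (see common_fixture.py)
    pytest tests/test_utils/test_storage_manager.py -k test_storage_mm_save_load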