mirror of https://github.com/InternLM/InternLM
fix(storage): fix and refactor storage api (#281)
parent
8d8d811e10
commit
8acf823a04
|
@ -46,12 +46,12 @@ def get_fns(fp: str):
|
|||
return storage_manager.get_fns(fp)
|
||||
|
||||
|
||||
def llm_load(fp: str, *args, **kwargs):
|
||||
return storage_manager.load(fp, *args, **kwargs)
|
||||
def llm_load(fp: str, **kwargs):
|
||||
return storage_manager.load(fp, **kwargs)
|
||||
|
||||
|
||||
def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
|
||||
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
|
||||
def llm_save(save_path: str, saved_obj: Any, **kwargs):
|
||||
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
|
||||
|
||||
|
||||
class StorageClient:
|
||||
|
@ -63,19 +63,23 @@ class StorageClient:
|
|||
self.handler = handler
|
||||
|
||||
@staticmethod
|
||||
def load(client, load_path: str, *args, **kwargs):
|
||||
def load(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
|
||||
def sync_upload_fileobj(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(client):
|
||||
def async_upload_fileobj(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_fns(client):
|
||||
def assert_fp_exists(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_fns(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -92,40 +96,65 @@ class Boto3MetaInfo:
|
|||
async_upload_fn: callable,
|
||||
local_nvme_path=None,
|
||||
) -> None:
|
||||
self.is_async = is_async
|
||||
# all need info.
|
||||
self.client = handler
|
||||
self.bucket_name = bucket_name
|
||||
self.endpoint = endpoint
|
||||
self.file_path = file_path
|
||||
self.async_upload_fn = async_upload_fn
|
||||
# only save need info.
|
||||
self.local_nvme_path = local_nvme_path
|
||||
self.is_async = is_async
|
||||
self.endpoint = endpoint
|
||||
self.async_upload_fn = async_upload_fn
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
|
||||
local_nvme_path: {self.local_nvme_path}"
|
||||
|
||||
@staticmethod
|
||||
def unpack_boto3_save_meta(meta):
|
||||
if meta.is_async:
|
||||
return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
|
||||
else:
|
||||
return meta.client, meta.bucket_name, meta.file_path
|
||||
|
||||
@staticmethod
|
||||
def unpack_boto3_nosave_meta(meta):
|
||||
return meta.client, meta.bucket_name, meta.file_path
|
||||
|
||||
|
||||
class LocalMetaInfo:
|
||||
"""Local meta info for save/load etc."""
|
||||
|
||||
def __init__(self, handler: StorageClient, dest_path: str) -> None:
|
||||
self.is_async = False
|
||||
self.client = handler
|
||||
self.dest_path = dest_path
|
||||
def __init__(self, file_path: str) -> None:
|
||||
self.file_path = file_path
|
||||
self.async_upload_fn = None
|
||||
self.is_async = False
|
||||
|
||||
@staticmethod
|
||||
def unpack_local_save_meta(meta):
|
||||
return (meta.file_path,)
|
||||
|
||||
@staticmethod
|
||||
def unpack_local_nosave_meta(meta):
|
||||
return (meta.file_path,)
|
||||
|
||||
|
||||
def unpack_meta(meta):
|
||||
args = []
|
||||
is_async = meta.is_async
|
||||
for k, v in meta.__dict__.items():
|
||||
if k in ("endpoint", "async_upload_fn", "is_async"):
|
||||
continue
|
||||
if not is_async and k in ("local_nvme_path",):
|
||||
continue
|
||||
args.append(v)
|
||||
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
|
||||
if isinstance(meta, Boto3MetaInfo):
|
||||
return Boto3MetaInfo.unpack_boto3_save_meta(meta)
|
||||
elif isinstance(meta, LocalMetaInfo):
|
||||
return LocalMetaInfo.unpack_local_save_meta(meta)
|
||||
else:
|
||||
raise ValueError(f"unkonwn meta info: {type(meta)}")
|
||||
|
||||
return args
|
||||
|
||||
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
|
||||
if isinstance(meta, Boto3MetaInfo):
|
||||
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
|
||||
elif isinstance(meta, LocalMetaInfo):
|
||||
return LocalMetaInfo.unpack_local_nosave_meta(meta)
|
||||
else:
|
||||
raise ValueError(f"unkonwn meta info: {type(meta)}")
|
||||
|
||||
|
||||
def compute_file_md5_by_chunk(file_name: str):
|
||||
|
@ -205,13 +234,11 @@ class Boto3Client(StorageClient):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(
|
||||
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
|
||||
): # pylint: disable=W0613
|
||||
def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
|
||||
assert saved_obj is not None, "saved_obj is None!"
|
||||
try:
|
||||
with io.BytesIO() as f:
|
||||
torch.save(saved_obj, f, *args, **kwargs)
|
||||
torch.save(saved_obj, f, **kwargs)
|
||||
f.seek(0)
|
||||
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
|
@ -220,14 +247,7 @@ class Boto3Client(StorageClient):
|
|||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
handler,
|
||||
bucket_name: str,
|
||||
fp: str,
|
||||
local_nvme_path: str, # pylint: disable=W0613
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
|
||||
"""
|
||||
Args:
|
||||
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
|
||||
|
@ -236,7 +256,7 @@ class Boto3Client(StorageClient):
|
|||
with io.BytesIO() as f:
|
||||
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
|
||||
f.seek(0)
|
||||
states = torch.load(f, *args, **kwargs)
|
||||
states = torch.load(f, **kwargs)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
|
||||
|
@ -244,11 +264,11 @@ class Boto3Client(StorageClient):
|
|||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
|
||||
def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
|
||||
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
|
||||
|
||||
@staticmethod
|
||||
def is_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
|
||||
def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
|
||||
re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
|
||||
if "Contents" in re:
|
||||
return len(list(re["Contents"])) > 0
|
||||
|
@ -256,12 +276,12 @@ class Boto3Client(StorageClient):
|
|||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
|
||||
def get_fns(handler, bucket_name: str, fp: str):
|
||||
"""
|
||||
Ref: https://stackoverflow.com/questions/54314563/
|
||||
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
|
||||
"""
|
||||
if Boto3Client.is_fp_exists(handler, bucket_name, fp, None):
|
||||
if Boto3Client.is_fp_exists(handler, bucket_name, fp):
|
||||
paginator = handler.client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
|
||||
folder_name_list = []
|
||||
|
@ -302,30 +322,26 @@ class LocalClient(StorageClient):
|
|||
super().__init__(None)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
|
||||
assert saved_obj is not None
|
||||
fp_dirname = os.path.dirname(fp)
|
||||
if not os.path.exists(fp_dirname):
|
||||
os.makedirs(fp_dirname, exist_ok=True)
|
||||
torch.save(saved_obj, fp, *args, **kwargs)
|
||||
torch.save(saved_obj, fp, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(fp), f"{fp} is not found!"
|
||||
with open(fp, "rb") as f:
|
||||
states = torch.load(f, *args, **kwargs)
|
||||
def load(load_path: str, **kwargs):
|
||||
assert os.path.exists(load_path), f"{load_path} is not found!"
|
||||
with open(load_path, "rb") as f:
|
||||
states = torch.load(f, **kwargs)
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(handler, folder):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def assert_fp_exists(folder):
|
||||
assert os.path.exists(folder), folder
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, folder):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def get_fns(folder):
|
||||
if not os.path.exists(folder):
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"'{folder}' not found!")
|
||||
|
@ -334,8 +350,7 @@ class LocalClient(StorageClient):
|
|||
return os.listdir(folder)
|
||||
|
||||
@staticmethod
|
||||
def delete_obj(handler, fp: str):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def delete_obj(fp: str):
|
||||
if not os.path.isdir(fp):
|
||||
os.remove(fp)
|
||||
|
||||
|
@ -359,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
|
|||
assert match is not None, f"url '{fp}' is not a valid boto3 url"
|
||||
bucket_name, endpoint = match.group(1), match.group(2)
|
||||
endpoint = "http://" + endpoint + ":80"
|
||||
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
|
||||
if is_async:
|
||||
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
|
||||
else:
|
||||
tmp_step_file = None
|
||||
return Boto3MetaInfo(
|
||||
is_async=is_async,
|
||||
handler=None,
|
||||
|
@ -373,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
|
|||
|
||||
def get_local_meta(fp: str) -> LocalMetaInfo:
|
||||
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
|
||||
return LocalMetaInfo(None, fp)
|
||||
return LocalMetaInfo(fp)
|
||||
|
||||
|
||||
def get_mount_point_free_size(path: str):
|
||||
|
@ -459,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
|
||||
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
|
||||
|
||||
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
||||
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
||||
"""
|
||||
example:
|
||||
local:/path/to/checkpoint
|
||||
|
@ -475,7 +493,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
meta_info = get_local_meta(path)
|
||||
backend_key = backend
|
||||
elif backend == "boto3":
|
||||
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
|
||||
meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
|
||||
backend_key = backend + ":" + meta_info.endpoint
|
||||
init_args = (meta_info.endpoint,)
|
||||
if (
|
||||
|
@ -503,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
|
||||
def assert_fp_exists(self, folder) -> None:
|
||||
meta = self._get_client(path=folder)
|
||||
meta.client.assert_fp_exists(*unpack_meta(meta))
|
||||
meta.client.assert_fp_exists(*unpack_nosave_meta(meta))
|
||||
|
||||
def get_fns(self, folder) -> List[str]:
|
||||
meta = self._get_client(path=folder)
|
||||
return meta.client.get_fns(*unpack_meta(meta))
|
||||
return meta.client.get_fns(*unpack_nosave_meta(meta))
|
||||
|
||||
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
|
||||
meta = self._get_client(path=save_path)
|
||||
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
|
||||
|
||||
if async_upload is None:
|
||||
async_upload = self.async_mode
|
||||
|
||||
if not save_path.startswith("boto3:"):
|
||||
async_upload = False
|
||||
|
||||
meta = self._get_client(save_path, async_upload)
|
||||
|
||||
if async_upload:
|
||||
assert (
|
||||
self.tmp_local_folder
|
||||
|
@ -521,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
tmp_step_file = meta.local_nvme_path
|
||||
self._to_be_del_files.append(tmp_step_file)
|
||||
with open(tmp_step_file, "wb") as f:
|
||||
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
|
||||
torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
|
||||
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||
self.async_task_peeding = True
|
||||
else:
|
||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||
meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
|
||||
self.upload_count += 1
|
||||
|
||||
def load(self, load_path: str, *args, **kwargs) -> Any:
|
||||
def load(self, load_path: str, **kwargs) -> Any:
|
||||
self.wait()
|
||||
meta = self._get_client(path=load_path)
|
||||
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
|
||||
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
|
||||
|
||||
def delete_obj(self, fp: str):
|
||||
meta = self._get_client(path=fp)
|
||||
meta.client.delete_obj(*unpack_meta(meta))
|
||||
meta.client.delete_obj(*unpack_nosave_meta(meta))
|
||||
|
||||
def _del_tmp_folder(self):
|
||||
for fp in self._to_be_del_files:
|
||||
|
@ -626,7 +649,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
if self.async_mode and self.latest_save_folder:
|
||||
self.save(
|
||||
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
|
||||
saved_obj=dict({"step": self.latest_save_step}),
|
||||
to_save_obj=dict({"step": self.latest_save_step}),
|
||||
async_upload=False,
|
||||
)
|
||||
self.latest_save_folder = None
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import shutil
|
||||
from subprocess import PIPE, STDOUT, Popen
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -8,6 +10,18 @@ from internlm.core.context.parallel_context import Config
|
|||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.utils.common import SingletonMeta
|
||||
|
||||
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
|
||||
OSS_IP = os.environ["OSS_IP"]
|
||||
USER = os.environ["USER"]
|
||||
JOB_NAME = "CI_TEST"
|
||||
LOCAL_SAVE_PATH = "local:local_ckpt"
|
||||
|
||||
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
||||
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
||||
|
||||
ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||
|
||||
|
||||
# 1B
|
||||
init_config = Config(
|
||||
dict(
|
||||
|
@ -108,8 +122,10 @@ def reset_singletons():
|
|||
|
||||
def reset_seed():
|
||||
from internlm.core.context.random import _SEED_MANAGER
|
||||
|
||||
_SEED_MANAGER.reset()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def init_dist_and_model():
|
||||
from internlm.initialize import initialize_distributed_env
|
||||
|
@ -136,8 +152,30 @@ def init_dist_and_model():
|
|||
reset_seed()
|
||||
|
||||
|
||||
|
||||
def enter_flag(text):
|
||||
print(f"{text} begin!", flush=True)
|
||||
yield
|
||||
print(f"{text} end!", flush=True)
|
||||
|
||||
|
||||
def del_tmp_file():
|
||||
try:
|
||||
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
|
||||
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
|
||||
results, presults = "", ""
|
||||
for line in iter(output.stdout.readline, b""):
|
||||
results += str(line.rstrip())
|
||||
presults += line.rstrip().decode() + "\n"
|
||||
print(presults, flush=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
import os
|
||||
import shutil
|
||||
from subprocess import PIPE, STDOUT, Popen
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -13,6 +11,10 @@ from internlm.utils.common import SingletonMeta
|
|||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
from internlm.utils.storage_manager import wait_async_upload_finish
|
||||
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
||||
ASYNC_TMP_FOLDER,
|
||||
BOTO_SAVE_PATH,
|
||||
LOCAL_SAVE_PATH,
|
||||
del_tmp_file,
|
||||
init_dist_and_model,
|
||||
reset_singletons,
|
||||
)
|
||||
|
@ -21,39 +23,6 @@ TOTAL_STEP = 6
|
|||
|
||||
CKPT_EVERY = 4
|
||||
SNPASHOT_EVERY = 2
|
||||
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
|
||||
OSS_IP = os.environ["OSS_IP"]
|
||||
USER = os.environ["USER"]
|
||||
JOB_NAME = "CI_TEST"
|
||||
LOCAL_SAVE_PATH = "local:local_ckpt"
|
||||
|
||||
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
||||
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
||||
|
||||
ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||
|
||||
|
||||
def del_tmp_file():
|
||||
try:
|
||||
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
|
||||
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
|
||||
results, presults = "", ""
|
||||
for line in iter(output.stdout.readline, b""):
|
||||
results += str(line.rstrip())
|
||||
presults += line.rstrip().decode() + "\n"
|
||||
print(presults, flush=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
ckpt_config_list = [
|
||||
|
|
|
@ -1,21 +1,75 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.initialize.launch import get_config_value
|
||||
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
||||
ASYNC_TMP_FOLDER,
|
||||
BOTO_SAVE_PATH,
|
||||
TOTAL_STEP,
|
||||
ckpt_config_list,
|
||||
LOCAL_SAVE_PATH,
|
||||
del_tmp_file,
|
||||
init_dist_and_model,
|
||||
reset_singletons,
|
||||
)
|
||||
|
||||
ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||
ckpt_config_list = [
|
||||
# async boto
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
async_upload=True,
|
||||
save_folder=BOTO_SAVE_PATH,
|
||||
test_id=0,
|
||||
),
|
||||
# sync local
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=None,
|
||||
async_upload=False,
|
||||
save_folder=LOCAL_SAVE_PATH,
|
||||
test_id=1,
|
||||
),
|
||||
# sync boto
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=None,
|
||||
async_upload=False,
|
||||
save_folder=BOTO_SAVE_PATH,
|
||||
test_id=2,
|
||||
),
|
||||
# async local
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
async_upload=True,
|
||||
save_folder=LOCAL_SAVE_PATH,
|
||||
test_id=3,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def del_tmp():
|
||||
del_tmp_file()
|
||||
yield
|
||||
del_tmp_file()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("del_tmp")
|
||||
@pytest.mark.usefixtures("reset_singletons")
|
||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||
def test_storage_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
|
||||
from internlm.utils.storage_manager import get_storage_manager, init_storage_manager
|
||||
def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
|
||||
from internlm.utils.storage_manager import (
|
||||
check_folder,
|
||||
get_fns,
|
||||
init_storage_manager,
|
||||
llm_load,
|
||||
llm_save,
|
||||
wait_async_upload_finish,
|
||||
)
|
||||
|
||||
ckpt_config = Config(ckpt_config)
|
||||
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
|
||||
|
@ -23,4 +77,13 @@ def test_storage_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable
|
|||
async_upload = get_config_value(ckpt_config, "async_upload", False)
|
||||
|
||||
init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
|
||||
get_storage_manager()
|
||||
|
||||
tobj = torch.rand(64, 64)
|
||||
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
|
||||
llm_save(save_fn, tobj)
|
||||
if ckpt_config.test_id == 0:
|
||||
wait_async_upload_finish()
|
||||
check_folder(save_fn)
|
||||
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
|
||||
load_obj = llm_load(save_fn, map_location="cpu")
|
||||
assert 0 == ((load_obj != tobj).sum())
|
||||
|
|
Loading…
Reference in New Issue