fix(storage): fix and refactor storage api (#281)

pull/282/head
Guoteng 2023-09-06 01:15:09 +08:00 committed by GitHub
parent 8d8d811e10
commit 8acf823a04
5 changed files with 205 additions and 112 deletions

internlm/utils/storage_manager.py

@@ -46,12 +46,12 @@ def get_fns(fp: str):
return storage_manager.get_fns(fp)
def llm_load(fp: str, *args, **kwargs):
return storage_manager.load(fp, *args, **kwargs)
def llm_load(fp: str, **kwargs):
return storage_manager.load(fp, **kwargs)
def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
def llm_save(save_path: str, saved_obj: Any, **kwargs):
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
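
With this refactor the module-level helpers are keyword-only pass-throughs to torch.save/torch.load; opaque *args are no longer threaded through every layer. A minimal usage sketch (the "local:ckpt/demo.pt" path is hypothetical, and init_storage_manager must have been called first, as the tests below do):

import torch
from internlm.utils.storage_manager import llm_load, llm_save

state = {"weight": torch.rand(4, 4)}
llm_save("local:ckpt/demo.pt", state)  # extra kwargs are forwarded to torch.save
loaded = llm_load("local:ckpt/demo.pt", map_location="cpu")  # extra kwargs go to torch.load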
class StorageClient:
@@ -63,19 +63,23 @@ class StorageClient:
self.handler = handler
@staticmethod
def load(client, load_path: str, *args, **kwargs):
def load(*args, **kwargs):
raise NotImplementedError
@staticmethod
def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
def sync_upload_fileobj(*args, **kwargs):
raise NotImplementedError
@staticmethod
def assert_fp_exists(client):
def async_upload_fileobj(*args, **kwargs):
raise NotImplementedError
@staticmethod
def get_fns(client):
def assert_fp_exists(*args, **kwargs):
raise NotImplementedError
@staticmethod
def get_fns(*args, **kwargs):
raise NotImplementedError
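
Every backend hook is now a uniform (*args, **kwargs) static stub, so the concrete argument tuple is decided entirely by the per-backend meta unpackers below. A hypothetical in-memory backend, sketched only to illustrate the contract (not part of this commit):

class DictClient(StorageClient):
    # Toy backend that keeps objects in a process-local dict, for illustration.
    _store = {}

    def __init__(self):
        super().__init__(None)

    @staticmethod
    def sync_upload_fileobj(fp, saved_obj=None, **kwargs):
        DictClient._store[fp] = saved_obj

    @staticmethod
    def load(fp, **kwargs):
        return DictClient._store[fp]

    @staticmethod
    def assert_fp_exists(fp):
        assert fp in DictClient._store, fp

    @staticmethod
    def get_fns(fp):
        return [k for k in DictClient._store if k.startswith(fp)]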
@@ -92,40 +96,65 @@ class Boto3MetaInfo:
async_upload_fn: callable,
local_nvme_path=None,
) -> None:
self.is_async = is_async
# all need info.
self.client = handler
self.bucket_name = bucket_name
self.endpoint = endpoint
self.file_path = file_path
self.async_upload_fn = async_upload_fn
# only save need info.
self.local_nvme_path = local_nvme_path
self.is_async = is_async
self.endpoint = endpoint
self.async_upload_fn = async_upload_fn
def __str__(self) -> str:
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"
@staticmethod
def unpack_boto3_save_meta(meta):
if meta.is_async:
return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
else:
return meta.client, meta.bucket_name, meta.file_path
@staticmethod
def unpack_boto3_nosave_meta(meta):
return meta.client, meta.bucket_name, meta.file_path
class LocalMetaInfo:
"""Local meta info for save/load etc."""
def __init__(self, handler: StorageClient, dest_path: str) -> None:
self.is_async = False
self.client = handler
self.dest_path = dest_path
def __init__(self, file_path: str) -> None:
self.file_path = file_path
self.async_upload_fn = None
self.is_async = False
@staticmethod
def unpack_local_save_meta(meta):
return (meta.file_path,)
@staticmethod
def unpack_local_nosave_meta(meta):
return (meta.file_path,)
def unpack_meta(meta):
args = []
is_async = meta.is_async
for k, v in meta.__dict__.items():
if k in ("endpoint", "async_upload_fn", "is_async"):
continue
if not is_async and k in ("local_nvme_path",):
continue
args.append(v)
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
if isinstance(meta, Boto3MetaInfo):
return Boto3MetaInfo.unpack_boto3_save_meta(meta)
elif isinstance(meta, LocalMetaInfo):
return LocalMetaInfo.unpack_local_save_meta(meta)
else:
raise ValueError(f"unkonwn meta info: {type(meta)}")
return args
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
if isinstance(meta, Boto3MetaInfo):
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
elif isinstance(meta, LocalMetaInfo):
return LocalMetaInfo.unpack_local_nosave_meta(meta)
else:
raise ValueError(f"unkonwn meta info: {type(meta)}")
def compute_file_md5_by_chunk(file_name: str):
@@ -205,13 +234,11 @@ class Boto3Client(StorageClient):
)
@staticmethod
def sync_upload_fileobj(
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
): # pylint: disable=W0613
def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
assert saved_obj is not None, "saved_obj is None!"
try:
with io.BytesIO() as f:
torch.save(saved_obj, f, *args, **kwargs)
torch.save(saved_obj, f, **kwargs)
f.seek(0)
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
except handler.botocore.exceptions.EndpointConnectionError as exc:
@@ -220,14 +247,7 @@ class Boto3Client(StorageClient):
) from exc
@staticmethod
def load(
handler,
bucket_name: str,
fp: str,
local_nvme_path: str, # pylint: disable=W0613
*args,
**kwargs,
) -> Dict:
def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
"""
Args:
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@@ -236,7 +256,7 @@ class Boto3Client(StorageClient):
with io.BytesIO() as f:
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
f.seek(0)
states = torch.load(f, *args, **kwargs)
states = torch.load(f, **kwargs)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@@ -244,11 +264,11 @@ class Boto3Client(StorageClient):
return states
@staticmethod
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
@staticmethod
def is_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
if "Contents" in re:
return len(list(re["Contents"])) > 0
@@ -256,12 +276,12 @@ class Boto3Client(StorageClient):
return False
@staticmethod
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
def get_fns(handler, bucket_name: str, fp: str):
"""
Ref: https://stackoverflow.com/questions/54314563/
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
"""
if Boto3Client.is_fp_exists(handler, bucket_name, fp, None):
if Boto3Client.is_fp_exists(handler, bucket_name, fp):
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
@@ -302,30 +322,26 @@ class LocalClient(StorageClient):
super().__init__(None)
@staticmethod
def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
assert isinstance(handler, LocalClient)
def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
assert saved_obj is not None
fp_dirname = os.path.dirname(fp)
if not os.path.exists(fp_dirname):
os.makedirs(fp_dirname, exist_ok=True)
torch.save(saved_obj, fp, *args, **kwargs)
torch.save(saved_obj, fp, **kwargs)
@staticmethod
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
assert isinstance(handler, LocalClient)
assert os.path.exists(fp), f"{fp} is not found!"
with open(fp, "rb") as f:
states = torch.load(f, *args, **kwargs)
def load(load_path: str, **kwargs):
assert os.path.exists(load_path), f"{load_path} is not found!"
with open(load_path, "rb") as f:
states = torch.load(f, **kwargs)
return states
@staticmethod
def assert_fp_exists(handler, folder):
assert isinstance(handler, LocalClient)
def assert_fp_exists(folder):
assert os.path.exists(folder), folder
@staticmethod
def get_fns(handler, folder):
assert isinstance(handler, LocalClient)
def get_fns(folder):
if not os.path.exists(folder):
if gpc.is_rank_for_log():
logger.warning(f"'{folder}' not found!")
@@ -334,8 +350,7 @@ class LocalClient(StorageClient):
return os.listdir(folder)
@staticmethod
def delete_obj(handler, fp: str):
assert isinstance(handler, LocalClient)
def delete_obj(fp: str):
if not os.path.isdir(fp):
os.remove(fp)
@@ -359,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
if is_async:
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
else:
tmp_step_file = None
return Boto3MetaInfo(
is_async=is_async,
handler=None,
@@ -373,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
return LocalMetaInfo(fp)
def get_mount_point_free_size(path: str):
@@ -459,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta):
logger.error(f'tmp_local_folder only has "{free_size}" GB free space, less than 100 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
"""
example:
local:/path/to/checkpoint
@@ -475,7 +493,7 @@
meta_info = get_local_meta(path)
backend_key = backend
elif backend == "boto3":
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (meta_info.endpoint,)
if (
@@ -503,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta):
def assert_fp_exists(self, folder) -> None:
meta = self._get_client(path=folder)
meta.client.assert_fp_exists(*unpack_meta(meta))
meta.client.assert_fp_exists(*unpack_nosave_meta(meta))
def get_fns(self, folder) -> List[str]:
meta = self._get_client(path=folder)
return meta.client.get_fns(*unpack_meta(meta))
return meta.client.get_fns(*unpack_nosave_meta(meta))
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
meta = self._get_client(path=save_path)
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
if async_upload is None:
async_upload = self.async_mode
if not save_path.startswith("boto3:"):
async_upload = False
meta = self._get_client(save_path, async_upload)
if async_upload:
assert (
self.tmp_local_folder
@@ -521,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta):
tmp_step_file = meta.local_nvme_path
self._to_be_del_files.append(tmp_step_file)
with open(tmp_step_file, "wb") as f:
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
self.async_task_peeding = True
else:
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
self.upload_count += 1
def load(self, load_path: str, *args, **kwargs) -> Any:
def load(self, load_path: str, **kwargs) -> Any:
self.wait()
meta = self._get_client(path=load_path)
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
def delete_obj(self, fp: str):
meta = self._get_client(path=fp)
meta.client.delete_obj(*unpack_meta(meta))
meta.client.delete_obj(*unpack_nosave_meta(meta))
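
save now resolves async_upload per call: it defaults to the manager-wide async_mode, is forced off for any non-boto3 destination, and the resolved flag is forwarded to _get_client so a boto3 meta only carries an NVMe staging path when one is actually needed. A hedged sketch of the resulting call patterns (the bucket address is hypothetical):

from internlm.utils.storage_manager import get_storage_manager, init_storage_manager

init_storage_manager(True, "./async_tmp_folder", True)  # enable_save_ckpt, tmp folder, async on
mgr = get_storage_manager()

mgr.save("local:ckpt/step.pt", to_save_obj={"step": 1})  # local target: async forced off
mgr.save("boto3:s3://bucket.10.0.0.1/user/job/step.pt", to_save_obj={"step": 1})  # async by default
mgr.save("boto3:s3://bucket.10.0.0.1/user/job/step.pt", to_save_obj={"step": 1}, async_upload=False)  # explicit sync
loaded = mgr.load("local:ckpt/step.pt", map_location="cpu")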
def _del_tmp_folder(self):
for fp in self._to_be_del_files:
@@ -626,7 +649,7 @@ class StorageManager(metaclass=SingletonMeta):
if self.async_mode and self.latest_save_folder:
self.save(
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
saved_obj=dict({"step": self.latest_save_step}),
to_save_obj=dict({"step": self.latest_save_step}),
async_upload=False,
)
self.latest_save_folder = None

tests/__init__.py Normal file

tests/test_utils/common_fixture.py

@@ -1,4 +1,6 @@
import os
import shutil
from subprocess import PIPE, STDOUT, Popen
import pytest
import torch
@@ -8,6 +10,18 @@ from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
OSS_IP = os.environ["OSS_IP"]
USER = os.environ["USER"]
JOB_NAME = "CI_TEST"
LOCAL_SAVE_PATH = "local:local_ckpt"
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
ASYNC_TMP_FOLDER = "./async_tmp_folder"
# 1B
init_config = Config(
dict(
@@ -108,8 +122,10 @@ def reset_singletons():
def reset_seed():
from internlm.core.context.random import _SEED_MANAGER
_SEED_MANAGER.reset()
@pytest.fixture(scope="module")
def init_dist_and_model():
from internlm.initialize import initialize_distributed_env
@@ -136,8 +152,30 @@ def init_dist_and_model():
reset_seed()
def enter_flag(text):
print(f"{text} begin!", flush=True)
yield
print(f"{text} end!", flush=True)
def del_tmp_file():
try:
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
except FileNotFoundError:
pass
try:
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
except FileNotFoundError:
pass
try:
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
results, presults = "", ""
for line in iter(output.stdout.readline, b""):
results += str(line.rstrip())
presults += line.rstrip().decode() + "\n"
print(presults, flush=True)
except FileNotFoundError:
pass

tests/test_utils/test_model_checkpoint.py

@@ -1,6 +1,4 @@
import os
import shutil
from subprocess import PIPE, STDOUT, Popen
import pytest
import torch
@@ -13,6 +11,10 @@ from internlm.utils.common import SingletonMeta
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.storage_manager import wait_async_upload_finish
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ASYNC_TMP_FOLDER,
BOTO_SAVE_PATH,
LOCAL_SAVE_PATH,
del_tmp_file,
init_dist_and_model,
reset_singletons,
)
@@ -21,39 +23,6 @@ TOTAL_STEP = 6
CKPT_EVERY = 4
SNPASHOT_EVERY = 2
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
OSS_IP = os.environ["OSS_IP"]
USER = os.environ["USER"]
JOB_NAME = "CI_TEST"
LOCAL_SAVE_PATH = "local:local_ckpt"
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
ASYNC_TMP_FOLDER = "./async_tmp_folder"
def del_tmp_file():
try:
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
except FileNotFoundError:
pass
try:
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
except FileNotFoundError:
pass
try:
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
results, presults = "", ""
for line in iter(output.stdout.readline, b""):
results += str(line.rstrip())
presults += line.rstrip().decode() + "\n"
print(presults, flush=True)
except FileNotFoundError:
pass
ckpt_config_list = [

tests/test_utils/test_storage_manager.py

@@ -1,21 +1,75 @@
import os
import pytest
import torch
from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ASYNC_TMP_FOLDER,
BOTO_SAVE_PATH,
TOTAL_STEP,
ckpt_config_list,
LOCAL_SAVE_PATH,
del_tmp_file,
init_dist_and_model,
reset_singletons,
)
ASYNC_TMP_FOLDER = "./async_tmp_folder"
ckpt_config_list = [
# async boto
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
async_upload=True,
save_folder=BOTO_SAVE_PATH,
test_id=0,
),
# sync local
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=None,
async_upload=False,
save_folder=LOCAL_SAVE_PATH,
test_id=1,
),
# sync boto
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=None,
async_upload=False,
save_folder=BOTO_SAVE_PATH,
test_id=2,
),
# async local
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
async_upload=True,
save_folder=LOCAL_SAVE_PATH,
test_id=3,
),
]
@pytest.fixture(scope="function")
def del_tmp():
del_tmp_file()
yield
del_tmp_file()
@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_storage_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
from internlm.utils.storage_manager import get_storage_manager, init_storage_manager
def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
from internlm.utils.storage_manager import (
check_folder,
get_fns,
init_storage_manager,
llm_load,
llm_save,
wait_async_upload_finish,
)
ckpt_config = Config(ckpt_config)
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
@@ -23,4 +77,13 @@ def test_storage_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
async_upload = get_config_value(ckpt_config, "async_upload", False)
init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
get_storage_manager()
tobj = torch.rand(64, 64)
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
llm_save(save_fn, tobj)
if ckpt_config.test_id == 0:
wait_async_upload_finish()
check_folder(save_fn)
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
load_obj = llm_load(save_fn, map_location="cpu")
assert 0 == ((load_obj != tobj).sum())
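
The closing assertion counts elementwise mismatches between the loaded and original tensors; an equivalent and arguably more direct check would be:

assert torch.equal(load_obj, tobj)  # True only for identical shape and values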