mirror of https://github.com/InternLM/InternLM
feat(storage): use multipart upload when using oss (#520)
* multipart upload * upload * storage * storage * storage * storagepull/526/head
parent
66bffffe5c
commit
1738bee002
|
@ -178,7 +178,8 @@ def args_sanity_check():
|
|||
else:
|
||||
if ckpt.async_upload:
|
||||
assert "save_ckpt_folder" in ckpt
|
||||
if "boto3:" not in ckpt.save_ckpt_folder:
|
||||
prefix_list = ["boto3:", "volc:", "oss2:"]
|
||||
if not any(ckpt.save_ckpt_folder.startswith(prefix) for prefix in prefix_list):
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
|
||||
|
|
|
@ -720,6 +720,7 @@ class CheckpointManager:
|
|||
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
|
||||
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
|
||||
|
||||
torch.distributed.barrier()
|
||||
# test storage setting is ok.
|
||||
if self.enable_save_ckpt:
|
||||
self.try_ping_storage()
|
||||
|
@ -1016,7 +1017,7 @@ now step_count is {train_state.step_count}",
|
|||
self.storage_manager.latest_save_step = step
|
||||
|
||||
def try_ping_storage(self):
|
||||
if gpc.get_global_rank() % 8 == 0:
|
||||
if gpc.is_rank_for_log():
|
||||
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
|
||||
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
|
||||
self.storage_manager.save(test_fn, buff)
|
||||
|
|
|
@ -18,10 +18,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import botocore
|
||||
|
@ -30,16 +26,38 @@ except ImportError:
|
|||
|
||||
try:
|
||||
import tos
|
||||
from tos.utils import SizeAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import oss2
|
||||
from oss2 import SizedFileAdapter, determine_part_size
|
||||
from oss2.models import PartInfo
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
logger = get_logger(__file__)
|
||||
class Logger:
|
||||
"Dummy logger"
|
||||
|
||||
def info(self, mesage: str):
|
||||
print(f"Info: {mesage}", flush=True)
|
||||
|
||||
def warning(self, mesage: str):
|
||||
print(f"Warning: {mesage}", flush=True)
|
||||
|
||||
def error(self, mesage: str):
|
||||
print(f"Error: {mesage}", flush=True)
|
||||
|
||||
|
||||
try:
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
except ImportError:
|
||||
logger = Logger()
|
||||
|
||||
|
||||
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
||||
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
|
||||
|
@ -66,6 +84,12 @@ def llm_save(save_path: str, saved_obj: Any, **kwargs):
|
|||
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
|
||||
|
||||
|
||||
def is_rank_for_log():
|
||||
if dist.is_initialized():
|
||||
return dist.get_rank() % 8 == 0
|
||||
return True
|
||||
|
||||
|
||||
class StorageClient:
|
||||
"""
|
||||
StorageClient as a client for s3 storage access.
|
||||
|
@ -267,21 +291,21 @@ def compute_file_md5_by_chunk(file_name: str):
|
|||
|
||||
def try_get_storage_backend(path: str):
|
||||
if path.startswith("s3:"):
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
||||
return "boto3", path
|
||||
elif path.startswith("vc:"):
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
|
||||
return "volc", path
|
||||
elif path.startswith("ali:"):
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.")
|
||||
return "oss2", path
|
||||
else:
|
||||
sre = path.split(":", maxsplit=1)
|
||||
if len(sre) == 1:
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
|
||||
return "local", sre[0]
|
||||
else:
|
||||
|
@ -399,7 +423,7 @@ class Boto3Client(StorageClient):
|
|||
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
|
||||
return list(set(folder_name_list))
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"'{fp}' not found!")
|
||||
return None
|
||||
|
||||
|
@ -530,14 +554,41 @@ class VolcClient(StorageClient):
|
|||
return list(set(folder_name_list))
|
||||
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"'{fp}' not found!")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
|
||||
try:
|
||||
handler.client.put_object_from_file(bucket_name, fp, local_nvme_path)
|
||||
total_size = os.path.getsize(local_nvme_path)
|
||||
part_size = 5 * 1024 * 1024
|
||||
|
||||
multi_result = handler.client.create_multipart_upload(bucket_name, fp)
|
||||
|
||||
upload_id = multi_result.upload_id
|
||||
parts = []
|
||||
|
||||
# Upload shard data
|
||||
with open(local_nvme_path, "rb") as f:
|
||||
part_number = 1
|
||||
offset = 0
|
||||
while offset < total_size:
|
||||
num_to_upload = min(part_size, total_size - offset)
|
||||
out = handler.client.upload_part(
|
||||
bucket_name,
|
||||
fp,
|
||||
upload_id,
|
||||
part_number,
|
||||
content=SizeAdapter(f, num_to_upload, init_offset=offset),
|
||||
)
|
||||
parts.append(out)
|
||||
offset += num_to_upload
|
||||
part_number += 1
|
||||
|
||||
# Complete the multipart upload task
|
||||
handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts)
|
||||
|
||||
except handler.handler.exceptions.TosClientError as exc:
|
||||
raise RuntimeError(
|
||||
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
|
||||
|
@ -548,6 +599,8 @@ class VolcClient(StorageClient):
|
|||
f"error with request id: {exec.request_id}",
|
||||
f"error with message: {exec.message}",
|
||||
f"error with http code: {exec.status_code}",
|
||||
f"error with ec: {exec.ec}",
|
||||
f"error with request url: {exec.request_url}",
|
||||
) from exc
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -570,10 +623,10 @@ class AliClient(StorageClient):
|
|||
"""Ali object/file storage management class
|
||||
|
||||
Args:
|
||||
access_key (str): Ali access key ID.
|
||||
access_key (str): Ali access key ID.s
|
||||
secret_key (str): Ali secret access key.
|
||||
endpoint (str): Ali tos endpoint.
|
||||
region (str): Ali tos region.
|
||||
bucket_name (str): Ali tos bucket_name.
|
||||
|
||||
"""
|
||||
super().__init__(oss2)
|
||||
|
@ -634,14 +687,34 @@ class AliClient(StorageClient):
|
|||
return list(set(folder_name_list))
|
||||
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"'{fp}' not found!")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def async_upload_fileobj(handler, fp: str, local_nvme_path: str):
|
||||
try:
|
||||
handler.client.put_object_from_file(fp, local_nvme_path)
|
||||
total_size = os.path.getsize(local_nvme_path)
|
||||
part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024)
|
||||
upload_id = handler.client.init_multipart_upload(fp).upload_id
|
||||
parts = []
|
||||
with open(local_nvme_path, "rb") as fileobj:
|
||||
part_number = 1
|
||||
offset = 0
|
||||
while offset < total_size:
|
||||
num_to_upload = min(part_size, total_size - offset)
|
||||
# Calling the SizedFileAdapter method will generate a new file object
|
||||
# and recalculate the starting append position.
|
||||
result = handler.client.upload_part(
|
||||
fp, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload)
|
||||
)
|
||||
parts.append(PartInfo(part_number, result.etag))
|
||||
|
||||
offset += num_to_upload
|
||||
part_number += 1
|
||||
|
||||
headers = dict()
|
||||
handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -683,7 +756,7 @@ class LocalClient(StorageClient):
|
|||
@staticmethod
|
||||
def get_fns(folder):
|
||||
if not os.path.exists(folder):
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
logger.warning(f"'{folder}' not found!")
|
||||
return None
|
||||
else:
|
||||
|
@ -815,6 +888,23 @@ def check_tmp_folder_accessibility(tmp_local_folder: str):
|
|||
raise RuntimeError(error_str)
|
||||
|
||||
|
||||
class SingletonMeta(type):
|
||||
"""
|
||||
Singleton Meta.
|
||||
"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
else:
|
||||
assert (
|
||||
len(args) == 0 and len(kwargs) == 0
|
||||
), f"{cls.__name__} is a singleton class and a instance has been created."
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class StorageManager(metaclass=SingletonMeta):
|
||||
"""
|
||||
Storage Manager for saving or loading checkpoint.
|
||||
|
@ -898,7 +988,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
or "HTTP_PROXY" in os.environ
|
||||
or "HTTPS_PROXY" in os.environ
|
||||
):
|
||||
if not self.has_warning and gpc.is_rank_for_log():
|
||||
if not self.has_warning and is_rank_for_log():
|
||||
logger.warning(
|
||||
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
||||
the proxy may make boto3 unavailable or affect performance."
|
||||
|
@ -917,7 +1007,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
or "HTTP_PROXY" in os.environ
|
||||
or "HTTPS_PROXY" in os.environ
|
||||
):
|
||||
if not self.has_warning and gpc.is_rank_for_log():
|
||||
if not self.has_warning and is_rank_for_log():
|
||||
logger.warning(
|
||||
"HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
|
||||
the proxy may make volc unavailable or affect performance."
|
||||
|
@ -936,7 +1026,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
or "HTTP_PROXY" in os.environ
|
||||
or "HTTPS_PROXY" in os.environ
|
||||
):
|
||||
if not self.has_warning and gpc.is_rank_for_log():
|
||||
if not self.has_warning and is_rank_for_log():
|
||||
logger.warning(
|
||||
"HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \
|
||||
the proxy may make oss2 unavailable or affect performance."
|
||||
|
@ -1082,7 +1172,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
self._to_be_del_files.clear()
|
||||
self.async_task_peeding = False
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
if is_rank_for_log():
|
||||
self.upload_count += 1
|
||||
if self.async_mode and self.latest_save_folder:
|
||||
self.save(
|
||||
|
|
|
@ -9,7 +9,7 @@ from internlm.core.context import global_context as gpc
|
|||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.train.utils import create_param_groups
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.storage_manager import SingletonMeta
|
||||
|
||||
OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
|
||||
OSS_IP = os.environ.get("OSS_IP", None)
|
||||
|
|
|
@ -100,7 +100,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
|
|||
init_storage_manager,
|
||||
llm_load,
|
||||
llm_save,
|
||||
wait_async_upload_finish,
|
||||
)
|
||||
|
||||
ckpt_config = Config(ckpt_config)
|
||||
|
@ -118,8 +117,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
|
|||
tobj = torch.rand(64, 64)
|
||||
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
|
||||
llm_save(save_fn, tobj)
|
||||
if ckpt_config.test_id == 0:
|
||||
wait_async_upload_finish()
|
||||
check_folder(save_fn)
|
||||
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
|
||||
load_obj = llm_load(save_fn, map_location="cpu")
|
||||
|
|
Loading…
Reference in New Issue