mirror of https://github.com/InternLM/InternLM
feat(storage): use multipart upload when using oss (#520)
* multipart upload * upload * storage * storage * storage * storagepull/526/head
parent
66bffffe5c
commit
1738bee002
|
@ -178,7 +178,8 @@ def args_sanity_check():
|
||||||
else:
|
else:
|
||||||
if ckpt.async_upload:
|
if ckpt.async_upload:
|
||||||
assert "save_ckpt_folder" in ckpt
|
assert "save_ckpt_folder" in ckpt
|
||||||
if "boto3:" not in ckpt.save_ckpt_folder:
|
prefix_list = ["boto3:", "volc:", "oss2:"]
|
||||||
|
if not any(ckpt.save_ckpt_folder.startswith(prefix) for prefix in prefix_list):
|
||||||
if gpc.is_rank_for_log():
|
if gpc.is_rank_for_log():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
|
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
|
||||||
|
|
|
@ -720,6 +720,7 @@ class CheckpointManager:
|
||||||
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
|
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
|
||||||
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
|
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
|
||||||
|
|
||||||
|
torch.distributed.barrier()
|
||||||
# test storage setting is ok.
|
# test storage setting is ok.
|
||||||
if self.enable_save_ckpt:
|
if self.enable_save_ckpt:
|
||||||
self.try_ping_storage()
|
self.try_ping_storage()
|
||||||
|
@ -1016,7 +1017,7 @@ now step_count is {train_state.step_count}",
|
||||||
self.storage_manager.latest_save_step = step
|
self.storage_manager.latest_save_step = step
|
||||||
|
|
||||||
def try_ping_storage(self):
|
def try_ping_storage(self):
|
||||||
if gpc.get_global_rank() % 8 == 0:
|
if gpc.is_rank_for_log():
|
||||||
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
|
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
|
||||||
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
|
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
|
||||||
self.storage_manager.save(test_fn, buff)
|
self.storage_manager.save(test_fn, buff)
|
||||||
|
|
|
@ -18,10 +18,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from internlm.core.context import global_context as gpc
|
|
||||||
from internlm.utils.common import SingletonMeta
|
|
||||||
from internlm.utils.logger import get_logger
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import boto3
|
import boto3
|
||||||
import botocore
|
import botocore
|
||||||
|
@ -30,16 +26,38 @@ except ImportError:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tos
|
import tos
|
||||||
|
from tos.utils import SizeAdapter
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import oss2
|
import oss2
|
||||||
|
from oss2 import SizedFileAdapter, determine_part_size
|
||||||
|
from oss2.models import PartInfo
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
class Logger:
|
||||||
|
"Dummy logger"
|
||||||
|
|
||||||
|
def info(self, mesage: str):
|
||||||
|
print(f"Info: {mesage}", flush=True)
|
||||||
|
|
||||||
|
def warning(self, mesage: str):
|
||||||
|
print(f"Warning: {mesage}", flush=True)
|
||||||
|
|
||||||
|
def error(self, mesage: str):
|
||||||
|
print(f"Error: {mesage}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from internlm.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__file__)
|
||||||
|
except ImportError:
|
||||||
|
logger = Logger()
|
||||||
|
|
||||||
|
|
||||||
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
||||||
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
|
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
|
||||||
|
@ -66,6 +84,12 @@ def llm_save(save_path: str, saved_obj: Any, **kwargs):
|
||||||
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
|
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def is_rank_for_log():
|
||||||
|
if dist.is_initialized():
|
||||||
|
return dist.get_rank() % 8 == 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class StorageClient:
|
class StorageClient:
|
||||||
"""
|
"""
|
||||||
StorageClient as a client for s3 storage access.
|
StorageClient as a client for s3 storage access.
|
||||||
|
@ -267,21 +291,21 @@ def compute_file_md5_by_chunk(file_name: str):
|
||||||
|
|
||||||
def try_get_storage_backend(path: str):
|
def try_get_storage_backend(path: str):
|
||||||
if path.startswith("s3:"):
|
if path.startswith("s3:"):
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
||||||
return "boto3", path
|
return "boto3", path
|
||||||
elif path.startswith("vc:"):
|
elif path.startswith("vc:"):
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
|
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
|
||||||
return "volc", path
|
return "volc", path
|
||||||
elif path.startswith("ali:"):
|
elif path.startswith("ali:"):
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.")
|
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.")
|
||||||
return "oss2", path
|
return "oss2", path
|
||||||
else:
|
else:
|
||||||
sre = path.split(":", maxsplit=1)
|
sre = path.split(":", maxsplit=1)
|
||||||
if len(sre) == 1:
|
if len(sre) == 1:
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
|
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
|
||||||
return "local", sre[0]
|
return "local", sre[0]
|
||||||
else:
|
else:
|
||||||
|
@ -399,7 +423,7 @@ class Boto3Client(StorageClient):
|
||||||
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
|
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
|
||||||
return list(set(folder_name_list))
|
return list(set(folder_name_list))
|
||||||
else:
|
else:
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"'{fp}' not found!")
|
logger.warning(f"'{fp}' not found!")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -530,14 +554,41 @@ class VolcClient(StorageClient):
|
||||||
return list(set(folder_name_list))
|
return list(set(folder_name_list))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"'{fp}' not found!")
|
logger.warning(f"'{fp}' not found!")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
|
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
|
||||||
try:
|
try:
|
||||||
handler.client.put_object_from_file(bucket_name, fp, local_nvme_path)
|
total_size = os.path.getsize(local_nvme_path)
|
||||||
|
part_size = 5 * 1024 * 1024
|
||||||
|
|
||||||
|
multi_result = handler.client.create_multipart_upload(bucket_name, fp)
|
||||||
|
|
||||||
|
upload_id = multi_result.upload_id
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
# Upload shard data
|
||||||
|
with open(local_nvme_path, "rb") as f:
|
||||||
|
part_number = 1
|
||||||
|
offset = 0
|
||||||
|
while offset < total_size:
|
||||||
|
num_to_upload = min(part_size, total_size - offset)
|
||||||
|
out = handler.client.upload_part(
|
||||||
|
bucket_name,
|
||||||
|
fp,
|
||||||
|
upload_id,
|
||||||
|
part_number,
|
||||||
|
content=SizeAdapter(f, num_to_upload, init_offset=offset),
|
||||||
|
)
|
||||||
|
parts.append(out)
|
||||||
|
offset += num_to_upload
|
||||||
|
part_number += 1
|
||||||
|
|
||||||
|
# Complete the multipart upload task
|
||||||
|
handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts)
|
||||||
|
|
||||||
except handler.handler.exceptions.TosClientError as exc:
|
except handler.handler.exceptions.TosClientError as exc:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
|
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
|
||||||
|
@ -548,6 +599,8 @@ class VolcClient(StorageClient):
|
||||||
f"error with request id: {exec.request_id}",
|
f"error with request id: {exec.request_id}",
|
||||||
f"error with message: {exec.message}",
|
f"error with message: {exec.message}",
|
||||||
f"error with http code: {exec.status_code}",
|
f"error with http code: {exec.status_code}",
|
||||||
|
f"error with ec: {exec.ec}",
|
||||||
|
f"error with request url: {exec.request_url}",
|
||||||
) from exc
|
) from exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -570,10 +623,10 @@ class AliClient(StorageClient):
|
||||||
"""Ali object/file storage management class
|
"""Ali object/file storage management class
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
access_key (str): Ali access key ID.
|
access_key (str): Ali access key ID.s
|
||||||
secret_key (str): Ali secret access key.
|
secret_key (str): Ali secret access key.
|
||||||
endpoint (str): Ali tos endpoint.
|
endpoint (str): Ali tos endpoint.
|
||||||
region (str): Ali tos region.
|
bucket_name (str): Ali tos bucket_name.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__(oss2)
|
super().__init__(oss2)
|
||||||
|
@ -634,14 +687,34 @@ class AliClient(StorageClient):
|
||||||
return list(set(folder_name_list))
|
return list(set(folder_name_list))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"'{fp}' not found!")
|
logger.warning(f"'{fp}' not found!")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def async_upload_fileobj(handler, fp: str, local_nvme_path: str):
|
def async_upload_fileobj(handler, fp: str, local_nvme_path: str):
|
||||||
try:
|
try:
|
||||||
handler.client.put_object_from_file(fp, local_nvme_path)
|
total_size = os.path.getsize(local_nvme_path)
|
||||||
|
part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024)
|
||||||
|
upload_id = handler.client.init_multipart_upload(fp).upload_id
|
||||||
|
parts = []
|
||||||
|
with open(local_nvme_path, "rb") as fileobj:
|
||||||
|
part_number = 1
|
||||||
|
offset = 0
|
||||||
|
while offset < total_size:
|
||||||
|
num_to_upload = min(part_size, total_size - offset)
|
||||||
|
# Calling the SizedFileAdapter method will generate a new file object
|
||||||
|
# and recalculate the starting append position.
|
||||||
|
result = handler.client.upload_part(
|
||||||
|
fp, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload)
|
||||||
|
)
|
||||||
|
parts.append(PartInfo(part_number, result.etag))
|
||||||
|
|
||||||
|
offset += num_to_upload
|
||||||
|
part_number += 1
|
||||||
|
|
||||||
|
headers = dict()
|
||||||
|
handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -683,7 +756,7 @@ class LocalClient(StorageClient):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fns(folder):
|
def get_fns(folder):
|
||||||
if not os.path.exists(folder):
|
if not os.path.exists(folder):
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
logger.warning(f"'{folder}' not found!")
|
logger.warning(f"'{folder}' not found!")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
@ -815,6 +888,23 @@ def check_tmp_folder_accessibility(tmp_local_folder: str):
|
||||||
raise RuntimeError(error_str)
|
raise RuntimeError(error_str)
|
||||||
|
|
||||||
|
|
||||||
|
class SingletonMeta(type):
|
||||||
|
"""
|
||||||
|
Singleton Meta.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances = {}
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
len(args) == 0 and len(kwargs) == 0
|
||||||
|
), f"{cls.__name__} is a singleton class and a instance has been created."
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
class StorageManager(metaclass=SingletonMeta):
|
class StorageManager(metaclass=SingletonMeta):
|
||||||
"""
|
"""
|
||||||
Storage Manager for saving or loading checkpoint.
|
Storage Manager for saving or loading checkpoint.
|
||||||
|
@ -898,7 +988,7 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
or "HTTP_PROXY" in os.environ
|
or "HTTP_PROXY" in os.environ
|
||||||
or "HTTPS_PROXY" in os.environ
|
or "HTTPS_PROXY" in os.environ
|
||||||
):
|
):
|
||||||
if not self.has_warning and gpc.is_rank_for_log():
|
if not self.has_warning and is_rank_for_log():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
||||||
the proxy may make boto3 unavailable or affect performance."
|
the proxy may make boto3 unavailable or affect performance."
|
||||||
|
@ -917,7 +1007,7 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
or "HTTP_PROXY" in os.environ
|
or "HTTP_PROXY" in os.environ
|
||||||
or "HTTPS_PROXY" in os.environ
|
or "HTTPS_PROXY" in os.environ
|
||||||
):
|
):
|
||||||
if not self.has_warning and gpc.is_rank_for_log():
|
if not self.has_warning and is_rank_for_log():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
|
"HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
|
||||||
the proxy may make volc unavailable or affect performance."
|
the proxy may make volc unavailable or affect performance."
|
||||||
|
@ -936,7 +1026,7 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
or "HTTP_PROXY" in os.environ
|
or "HTTP_PROXY" in os.environ
|
||||||
or "HTTPS_PROXY" in os.environ
|
or "HTTPS_PROXY" in os.environ
|
||||||
):
|
):
|
||||||
if not self.has_warning and gpc.is_rank_for_log():
|
if not self.has_warning and is_rank_for_log():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \
|
"HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \
|
||||||
the proxy may make oss2 unavailable or affect performance."
|
the proxy may make oss2 unavailable or affect performance."
|
||||||
|
@ -1082,7 +1172,7 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
self._to_be_del_files.clear()
|
self._to_be_del_files.clear()
|
||||||
self.async_task_peeding = False
|
self.async_task_peeding = False
|
||||||
|
|
||||||
if gpc.is_rank_for_log():
|
if is_rank_for_log():
|
||||||
self.upload_count += 1
|
self.upload_count += 1
|
||||||
if self.async_mode and self.latest_save_folder:
|
if self.async_mode and self.latest_save_folder:
|
||||||
self.save(
|
self.save(
|
||||||
|
|
|
@ -9,7 +9,7 @@ from internlm.core.context import global_context as gpc
|
||||||
from internlm.core.context.parallel_context import Config
|
from internlm.core.context.parallel_context import Config
|
||||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||||
from internlm.train.utils import create_param_groups
|
from internlm.train.utils import create_param_groups
|
||||||
from internlm.utils.common import SingletonMeta
|
from internlm.utils.storage_manager import SingletonMeta
|
||||||
|
|
||||||
OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
|
OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
|
||||||
OSS_IP = os.environ.get("OSS_IP", None)
|
OSS_IP = os.environ.get("OSS_IP", None)
|
||||||
|
|
|
@ -100,7 +100,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
|
||||||
init_storage_manager,
|
init_storage_manager,
|
||||||
llm_load,
|
llm_load,
|
||||||
llm_save,
|
llm_save,
|
||||||
wait_async_upload_finish,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ckpt_config = Config(ckpt_config)
|
ckpt_config = Config(ckpt_config)
|
||||||
|
@ -118,8 +117,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg
|
||||||
tobj = torch.rand(64, 64)
|
tobj = torch.rand(64, 64)
|
||||||
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
|
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
|
||||||
llm_save(save_fn, tobj)
|
llm_save(save_fn, tobj)
|
||||||
if ckpt_config.test_id == 0:
|
|
||||||
wait_async_upload_finish()
|
|
||||||
check_folder(save_fn)
|
check_folder(save_fn)
|
||||||
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
|
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
|
||||||
load_obj = llm_load(save_fn, map_location="cpu")
|
load_obj = llm_load(save_fn, map_location="cpu")
|
||||||
|
|
Loading…
Reference in New Issue