From 3b7fb97e0491f4a2fc13f9e69a9ab71bb0cc2c59 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 1 Dec 2023 11:10:04 +0800 Subject: [PATCH] storage --- internlm/utils/model_checkpoint.py | 7 +- internlm/utils/storage_manager.py | 107 ++++++++++++++++------- tests/test_utils/common_fixture.py | 2 +- tests/test_utils/test_storage_manager.py | 3 - 4 files changed, 79 insertions(+), 40 deletions(-) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 2222ed4..87a303c 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -720,9 +720,10 @@ class CheckpointManager: self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"]) + torch.distributed.barrier() # test storage setting is ok. - # if self.enable_save_ckpt: - # self.try_ping_storage() + if self.enable_save_ckpt: + self.try_ping_storage() def quit_signal_handler(self, train_state) -> bool: """ @@ -1016,7 +1017,7 @@ now step_count is {train_state.step_count}", self.storage_manager.latest_save_step = step def try_ping_storage(self): - if gpc.get_global_rank() % 8 == 0: + if gpc.is_rank_for_log(): buff = torch.ones((1, 64, 64), dtype=torch.bfloat16) test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping") self.storage_manager.save(test_fn, buff) diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 0b8c561..8a05637 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -18,10 +18,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Union import torch import torch.distributed as dist -from internlm.core.context import global_context as gpc -from internlm.utils.common import SingletonMeta -from internlm.utils.logger import get_logger - try: import boto3 import botocore @@ -30,8 +26,7 @@ except ImportError: try: import tos - from tos import DataTransferType - from tos.utils import SizeAdapter, MergeProcess + from tos.utils import SizeAdapter except ImportError: pass @@ -43,7 +38,26 @@ except ImportError: pass -logger = get_logger(__file__) +class Logger: + "Dummy logger" + + def info(self, mesage: str): + print(f"Info: {mesage}", flush=True) + + def warning(self, mesage: str): + print(f"Warning: {mesage}", flush=True) + + def error(self, mesage: str): + print(f"Error: {mesage}", flush=True) + + +try: + from internlm.utils.logger import get_logger + + logger = get_logger(__file__) +except ImportError: + logger = Logger() + boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)") volc_url_re = re.compile(r"^(.*?)\.(.*)$") @@ -68,7 +82,13 @@ def llm_load(fp: str, **kwargs): def llm_save(save_path: str, saved_obj: Any, **kwargs): storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs) - + + +def is_rank_for_log(): + if dist.is_initialized(): + return dist.get_rank() % 8 == 0 + return True + class StorageClient: """ @@ -271,21 +291,21 @@ def compute_file_md5_by_chunk(file_name: str): def try_get_storage_backend(path: str): if path.startswith("s3:"): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") return "boto3", path elif path.startswith("vc:"): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.") return "volc", path elif path.startswith("ali:"): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.") return "oss2", path else: sre = path.split(":", maxsplit=1) if len(sre) == 1: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.") return "local", sre[0] else: @@ -403,7 +423,7 @@ class Boto3Client(StorageClient): folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @@ -534,7 +554,7 @@ class VolcClient(StorageClient): return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @@ -550,22 +570,25 @@ class VolcClient(StorageClient): parts = [] # 上传分片数据 - logger.info('Begin multipart upload of one file') - with open(local_nvme_path, 'rb') as f: + with open(local_nvme_path, "rb") as f: part_number = 1 offset = 0 while offset < total_size: num_to_upload = min(part_size, total_size - offset) - out = handler.client.upload_part(bucket_name, fp, upload_id, part_number, - content=SizeAdapter(f, num_to_upload, init_offset=offset)) + out = handler.client.upload_part( + bucket_name, + fp, + upload_id, + part_number, + content=SizeAdapter(f, num_to_upload, init_offset=offset), + ) parts.append(out) offset += num_to_upload part_number += 1 # 完成分片上传任务 handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts) - logger.info('Finish multipart upload of one file') - + except handler.handler.exceptions.TosClientError as exc: raise RuntimeError( f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}" @@ -600,10 +623,10 @@ class AliClient(StorageClient): """Ali object/file storage management class Args: - access_key (str): Ali access key ID. + access_key (str): Ali access key ID.s secret_key (str): Ali secret access key. endpoint (str): Ali tos endpoint. - region (str): Ali tos region. + bucket_name (str): Ali tos bucket_name. """ super().__init__(oss2) @@ -664,7 +687,7 @@ class AliClient(StorageClient): return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @@ -672,24 +695,25 @@ class AliClient(StorageClient): def async_upload_fileobj(handler, fp: str, local_nvme_path: str): try: # handler.client.put_object_from_file(fp, local_nvme_path) - + total_size = os.path.getsize(local_nvme_path) part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024) upload_id = handler.client.init_multipart_upload(fp).upload_id parts = [] - with open(local_nvme_path, 'rb') as fileobj: + with open(local_nvme_path, "rb") as fileobj: part_number = 1 offset = 0 while offset < total_size: num_to_upload = min(part_size, total_size - offset) # 调用SizedFileAdapter(fileobj, size)方法会生成一个新的文件对象,重新计算起始追加位置。 - result = handler.client.upload_part(fp, upload_id, part_number, - SizedFileAdapter(fileobj, num_to_upload)) + result = handler.client.upload_part( + fp, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload) + ) parts.append(PartInfo(part_number, result.etag)) offset += num_to_upload part_number += 1 - + headers = dict() handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers) except Exception as e: @@ -733,7 +757,7 @@ class LocalClient(StorageClient): @staticmethod def get_fns(folder): if not os.path.exists(folder): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{folder}' not found!") return None else: @@ -865,6 +889,23 @@ def check_tmp_folder_accessibility(tmp_local_folder: str): raise RuntimeError(error_str) +class SingletonMeta(type): + """ + Singleton Meta. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + else: + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and a instance has been created." + return cls._instances[cls] + + class StorageManager(metaclass=SingletonMeta): """ Storage Manager for saving or loading checkpoint. @@ -948,7 +989,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \ the proxy may make boto3 unavailable or affect performance." @@ -967,7 +1008,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using volc, incorrectly setting \ the proxy may make volc unavailable or affect performance." @@ -986,7 +1027,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \ the proxy may make oss2 unavailable or affect performance." @@ -1132,7 +1173,7 @@ class StorageManager(metaclass=SingletonMeta): self._to_be_del_files.clear() self.async_task_peeding = False - if gpc.is_rank_for_log(): + if is_rank_for_log(): self.upload_count += 1 if self.async_mode and self.latest_save_folder: self.save( diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 56e7b21..746f43e 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -8,7 +8,7 @@ import torch from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer -from internlm.utils.common import SingletonMeta +from internlm.utils.storage_manager import SingletonMeta OSS_NAME = os.environ.get("OSS_BUCKET_NAME") OSS_IP = os.environ.get("OSS_IP") diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py index e102ca1..9454a83 100644 --- a/tests/test_utils/test_storage_manager.py +++ b/tests/test_utils/test_storage_manager.py @@ -100,7 +100,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg init_storage_manager, llm_load, llm_save, - wait_async_upload_finish, ) ckpt_config = Config(ckpt_config) @@ -118,8 +117,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg tobj = torch.rand(64, 64) save_fn = os.path.join(ckpt_config.save_folder, "test.pt") llm_save(save_fn, tobj) - if ckpt_config.test_id == 0: - wait_async_upload_finish() check_folder(save_fn) assert get_fns(ckpt_config.save_folder)[0] == "test.pt" load_obj = llm_load(save_fn, map_location="cpu")