diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 2736532..82a4a21 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -178,7 +178,8 @@ def args_sanity_check():
     else:
         if ckpt.async_upload:
             assert "save_ckpt_folder" in ckpt
-            if "boto3:" not in ckpt.save_ckpt_folder:
+            prefix_list = ["boto3:", "volc:", "oss2:"]
+            if not any(ckpt.save_ckpt_folder.startswith(prefix) for prefix in prefix_list):
                 if gpc.is_rank_for_log():
                     logger.warning(
                         "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index d16db0c..87a303c 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -720,6 +720,7 @@ class CheckpointManager:
             self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
             self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
 
+        torch.distributed.barrier()
         # test storage setting is ok.
         if self.enable_save_ckpt:
             self.try_ping_storage()
@@ -1016,7 +1017,7 @@ now step_count is {train_state.step_count}",
         self.storage_manager.latest_save_step = step
 
     def try_ping_storage(self):
-        if gpc.get_global_rank() % 8 == 0:
+        if gpc.is_rank_for_log():
             buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
             test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
             self.storage_manager.save(test_fn, buff)
diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py
index c76b570..14c620b 100644
--- a/internlm/utils/storage_manager.py
+++ b/internlm/utils/storage_manager.py
@@ -18,10 +18,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Union
 import torch
 import torch.distributed as dist
 
-from internlm.core.context import global_context as gpc
-from internlm.utils.common import SingletonMeta
-from internlm.utils.logger import get_logger
-
 try:
     import boto3
     import botocore
@@ -30,16 +26,38 @@ except ImportError:
     pass
 
 try:
     import tos
+    from tos.utils import SizeAdapter
 except ImportError:
     pass
 
 try:
     import oss2
+    from oss2 import SizedFileAdapter, determine_part_size
+    from oss2.models import PartInfo
 except ImportError:
     pass
 
-logger = get_logger(__file__)
+
+class Logger:
+    "Dummy logger"
+
+    def info(self, message: str):
+        print(f"Info: {message}", flush=True)
+
+    def warning(self, message: str):
+        print(f"Warning: {message}", flush=True)
+
+    def error(self, message: str):
+        print(f"Error: {message}", flush=True)
+
+
+try:
+    from internlm.utils.logger import get_logger
+
+    logger = get_logger(__file__)
+except ImportError:
+    logger = Logger()
+
 
 boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
 volc_url_re = re.compile(r"^(.*?)\.(.*)$")
@@ -66,6 +84,12 @@ def llm_save(save_path: str, saved_obj: Any, **kwargs):
     storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
 
 
+def is_rank_for_log():
+    if dist.is_initialized():
+        return dist.get_rank() % 8 == 0
+    return True
+
+
 class StorageClient:
     """
     StorageClient as a client for s3 storage access.
@@ -267,21 +291,21 @@ def compute_file_md5_by_chunk(file_name: str):
 
 def try_get_storage_backend(path: str):
     if path.startswith("s3:"):
-        if gpc.is_rank_for_log():
+        if is_rank_for_log():
             logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
         return "boto3", path
     elif path.startswith("vc:"):
-        if gpc.is_rank_for_log():
+        if is_rank_for_log():
             logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
         return "volc", path
     elif path.startswith("ali:"):
-        if gpc.is_rank_for_log():
+        if is_rank_for_log():
             logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.")
         return "oss2", path
     else:
         sre = path.split(":", maxsplit=1)
         if len(sre) == 1:
-            if gpc.is_rank_for_log():
+            if is_rank_for_log():
                 logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
             return "local", sre[0]
         else:
@@ -399,7 +423,7 @@ class Boto3Client(StorageClient):
                 folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
             return list(set(folder_name_list))
         else:
-            if gpc.is_rank_for_log():
+            if is_rank_for_log():
                 logger.warning(f"'{fp}' not found!")
             return None
@@ -530,14 +554,41 @@ class VolcClient(StorageClient):
 
             return list(set(folder_name_list))
         else:
-            if gpc.is_rank_for_log():
+            if is_rank_for_log():
                 logger.warning(f"'{fp}' not found!")
             return None
 
     @staticmethod
     def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
         try:
-            handler.client.put_object_from_file(bucket_name, fp, local_nvme_path)
+            total_size = os.path.getsize(local_nvme_path)
+            part_size = 5 * 1024 * 1024
+
+            multi_result = handler.client.create_multipart_upload(bucket_name, fp)
+
+            upload_id = multi_result.upload_id
+            parts = []
+
+            # Upload shard data
+            with open(local_nvme_path, "rb") as f:
+                part_number = 1
+                offset = 0
+                while offset < total_size:
+                    num_to_upload = min(part_size, total_size - offset)
+                    out = handler.client.upload_part(
+                        bucket_name,
+                        fp,
+                        upload_id,
+                        part_number,
+                        content=SizeAdapter(f, num_to_upload, init_offset=offset),
+                    )
+                    parts.append(out)
+                    offset += num_to_upload
+                    part_number += 1
+
+            # Complete the multipart upload task
+            handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts)
+
         except handler.handler.exceptions.TosClientError as exc:
             raise RuntimeError(
                 f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
@@ -548,6 +599,8 @@ class VolcClient(StorageClient):
                 f"error with request id: {exc.request_id}",
                 f"error with message: {exc.message}",
                 f"error with http code: {exc.status_code}",
+                f"error with ec: {exc.ec}",
+                f"error with request url: {exc.request_url}",
             ) from exc
         except Exception as e:
             raise e
@@ -570,10 +623,10 @@ class AliClient(StorageClient):
         """Ali object/file storage management class
 
         Args:
             access_key (str): Ali access key ID.
             secret_key (str): Ali secret access key.
             endpoint (str): Ali tos endpoint.
-            region (str): Ali tos region.
+            bucket_name (str): Ali tos bucket_name.
""" super().__init__(oss2) @@ -634,14 +687,34 @@ class AliClient(StorageClient): return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @staticmethod def async_upload_fileobj(handler, fp: str, local_nvme_path: str): try: - handler.client.put_object_from_file(fp, local_nvme_path) + total_size = os.path.getsize(local_nvme_path) + part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024) + upload_id = handler.client.init_multipart_upload(fp).upload_id + parts = [] + with open(local_nvme_path, "rb") as fileobj: + part_number = 1 + offset = 0 + while offset < total_size: + num_to_upload = min(part_size, total_size - offset) + # Calling the SizedFileAdapter method will generate a new file object + # and recalculate the starting append position. + result = handler.client.upload_part( + fp, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload) + ) + parts.append(PartInfo(part_number, result.etag)) + + offset += num_to_upload + part_number += 1 + + headers = dict() + handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers) except Exception as e: raise e @@ -683,7 +756,7 @@ class LocalClient(StorageClient): @staticmethod def get_fns(folder): if not os.path.exists(folder): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{folder}' not found!") return None else: @@ -815,6 +888,23 @@ def check_tmp_folder_accessibility(tmp_local_folder: str): raise RuntimeError(error_str) +class SingletonMeta(type): + """ + Singleton Meta. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + else: + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and a instance has been created." + return cls._instances[cls] + + class StorageManager(metaclass=SingletonMeta): """ Storage Manager for saving or loading checkpoint. @@ -898,7 +988,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \ the proxy may make boto3 unavailable or affect performance." @@ -917,7 +1007,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using volc, incorrectly setting \ the proxy may make volc unavailable or affect performance." @@ -936,7 +1026,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \ the proxy may make oss2 unavailable or affect performance." 
@@ -1082,7 +1172,7 @@ class StorageManager(metaclass=SingletonMeta):
         self._to_be_del_files.clear()
         self.async_task_peeding = False
 
-        if gpc.is_rank_for_log():
+        if is_rank_for_log():
             self.upload_count += 1
             if self.async_mode and self.latest_save_folder:
                 self.save(
diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py
index d0f1455..6096156 100644
--- a/tests/test_utils/common_fixture.py
+++ b/tests/test_utils/common_fixture.py
@@ -9,7 +9,7 @@ from internlm.core.context import global_context as gpc
 from internlm.core.context.parallel_context import Config
 from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import SingletonMeta
+from internlm.utils.storage_manager import SingletonMeta
 
 OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
 OSS_IP = os.environ.get("OSS_IP", None)
diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py
index e102ca1..9454a83 100644
--- a/tests/test_utils/test_storage_manager.py
+++ b/tests/test_utils/test_storage_manager.py
@@ -100,7 +100,6 @@ def test_storage_mm_save_load(ckpt_config):  # noqa  # pylint: disable=unused-arg
         init_storage_manager,
         llm_load,
         llm_save,
-        wait_async_upload_finish,
     )
 
     ckpt_config = Config(ckpt_config)
@@ -118,8 +117,6 @@ def test_storage_mm_save_load(ckpt_config):  # noqa  # pylint: disable=unused-arg
     tobj = torch.rand(64, 64)
     save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
     llm_save(save_fn, tobj)
-    if ckpt_config.test_id == 0:
-        wait_async_upload_finish()
     check_folder(save_fn)
     assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
     load_obj = llm_load(save_fn, map_location="cpu")
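
Note: the multipart flow that AliClient.async_upload_fileobj now follows can be exercised against oss2 directly. Below is a minimal standalone sketch of the same pattern; the endpoint, credentials, bucket, and object key are placeholders, not values taken from this patch.

# Minimal standalone sketch of the oss2 multipart upload pattern used by
# AliClient.async_upload_fileobj. "<access_key_id>", "<access_key_secret>",
# "<endpoint>" and "<bucket_name>" are placeholder assumptions.
import os

import oss2
from oss2 import SizedFileAdapter, determine_part_size
from oss2.models import PartInfo

auth = oss2.Auth("<access_key_id>", "<access_key_secret>")
bucket = oss2.Bucket(auth, "<endpoint>", "<bucket_name>")


def multipart_upload(key: str, local_path: str) -> None:
    total_size = os.path.getsize(local_path)
    # Let oss2 pick a part size, preferring parts of at least 5 MiB.
    part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024)
    upload_id = bucket.init_multipart_upload(key).upload_id
    parts = []
    with open(local_path, "rb") as fileobj:
        part_number = 1
        offset = 0
        while offset < total_size:
            num_to_upload = min(part_size, total_size - offset)
            # SizedFileAdapter exposes exactly num_to_upload bytes starting
            # at the file object's current read position.
            result = bucket.upload_part(
                key, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload)
            )
            parts.append(PartInfo(part_number, result.etag))
            offset += num_to_upload
            part_number += 1
    bucket.complete_multipart_upload(key, upload_id, parts)

Uploading in parts bounds the size of each request and lets a failed part be retried without resending the whole checkpoint shard, which is presumably why both the volc and oss2 backends move away from put_object_from_file here.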
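Note: SingletonMeta is vendored into storage_manager.py so the module no longer needs internlm.utils.common (which pulled in gpc). A quick illustration of its contract, assuming the patched module is importable; ConnectionPool is a hypothetical example class, not part of this patch.

# Illustration of the SingletonMeta contract vendored above.
# ConnectionPool is a hypothetical example class for demonstration only.
from internlm.utils.storage_manager import SingletonMeta


class ConnectionPool(metaclass=SingletonMeta):
    def __init__(self, size: int = 8):
        self.size = size


a = ConnectionPool(size=32)  # first call constructs the singleton
b = ConnectionPool()         # later calls return the same instance
assert a is b and b.size == 32
# ConnectionPool(size=64) would raise an AssertionError: the metaclass
# rejects constructor arguments once the instance exists, since they
# would otherwise be silently ignored.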