From 06cdcc3654f7c81cbd6fb490cdf74896ebbc6f97 Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Wed, 29 Nov 2023 11:08:40 +0800
Subject: [PATCH] storage: support async multipart upload for volc and oss2 backends

args_sanity_check() now accepts the "volc:" and "oss2:" prefixes (besides
"boto3:") when deciding whether save_ckpt_folder supports asynchronous
saving. VolcClient and AliClient upload staged checkpoints via multipart
upload, the unused percentage() progress callback is removed, and the
startup storage ping in CheckpointManager is disabled for now.
---
 internlm/initialize/launch.py      |  3 ++-
 internlm/utils/model_checkpoint.py |  4 +--
 internlm/utils/storage_manager.py  | 41 +++++++++++++++++++++---------
 3 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index e96d2d9..c8c8c4a 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -178,7 +178,8 @@ def args_sanity_check():
     else:
         if ckpt.async_upload:
             assert "save_ckpt_folder" in ckpt
-            if "boto3:" not in ckpt.save_ckpt_folder:
+            prefix_list = ["boto3:", "volc:", "oss2:"]
+            if not any(ckpt.save_ckpt_folder.startswith(prefix) for prefix in prefix_list):
                 if gpc.is_rank_for_log():
                     logger.warning(
                         "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index d16db0c..2222ed4 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -721,8 +721,8 @@ class CheckpointManager:
         self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
 
         # test storage setting is ok.
-        if self.enable_save_ckpt:
-            self.try_ping_storage()
+        # if self.enable_save_ckpt:
+        #     self.try_ping_storage()
 
     def quit_signal_handler(self, train_state) -> bool:
         """
diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py
index 86adaaf..0b8c561 100644
--- a/internlm/utils/storage_manager.py
+++ b/internlm/utils/storage_manager.py
@@ -37,6 +37,8 @@ except ImportError:
 
 try:
     import oss2
+    from oss2 import SizedFileAdapter, determine_part_size
+    from oss2.models import PartInfo
 except ImportError:
     pass
 
@@ -66,13 +68,7 @@ def llm_load(fp: str, **kwargs):
 
 def llm_save(save_path: str, saved_obj: Any, **kwargs):
     storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
-
-
-def percentage(consumed_bytes: int, total_bytes: int, rw_once_bytes: int, type: DataTransferType):
-    if total_bytes and gpc.is_rank_for_log():
-        rate = int(100 * float(consumed_bytes) / float(total_bytes))
-        logger.info(f"rate:{rate}, consumed_bytes:{consumed_bytes},total_bytes{total_bytes}, rw_once_bytes:{rw_once_bytes}, type:{type}")
-
+
 
 class StorageClient:
     """
@@ -547,28 +543,29 @@ class VolcClient(StorageClient):
        try:
            total_size = os.path.getsize(local_nvme_path)
            part_size = 5 * 1024 * 1024
-
-           data_transfer_listener = MergeProcess(percentage, total_size, (total_size + part_size - 1) // part_size, 0)
+
            multi_result = handler.client.create_multipart_upload(bucket_name, fp)
            upload_id = multi_result.upload_id
            parts = []
            # upload the data part by part
+           logger.info('Begin multipart upload of one file')
            with open(local_nvme_path, 'rb') as f:
                part_number = 1
                offset = 0
                while offset < total_size:
                    num_to_upload = min(part_size, total_size - offset)
                    out = handler.client.upload_part(bucket_name, fp, upload_id, part_number,
-                                                    content=SizeAdapter(f, num_to_upload, init_offset=offset),
-                                                    data_transfer_listener=data_transfer_listener)
+                                                    content=SizeAdapter(f, num_to_upload, init_offset=offset))
                    parts.append(out)
                    offset += num_to_upload
                    part_number += 1
            # complete the multipart upload task
            handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts)
+           logger.info('Finish multipart upload of one file')
+
        except handler.handler.exceptions.TosClientError as exc:
            raise RuntimeError(
                f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
@@ -674,7 +671,27 @@ class AliClient(StorageClient):
     @staticmethod
     def async_upload_fileobj(handler, fp: str, local_nvme_path: str):
         try:
-            handler.client.put_object_from_file(fp, local_nvme_path)
+            # handler.client.put_object_from_file(fp, local_nvme_path)
+
+            total_size = os.path.getsize(local_nvme_path)
+            part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024)
+            upload_id = handler.client.init_multipart_upload(fp).upload_id
+            parts = []
+            with open(local_nvme_path, 'rb') as fileobj:
+                part_number = 1
+                offset = 0
+                while offset < total_size:
+                    num_to_upload = min(part_size, total_size - offset)
+                    # Calling SizedFileAdapter(fileobj, size) creates a new file-like
+                    # object that recomputes the start position for the next part.
+                    result = handler.client.upload_part(fp, upload_id, part_number,
+                                                        SizedFileAdapter(fileobj, num_to_upload))
+                    parts.append(PartInfo(part_number, result.etag))
+
+                    offset += num_to_upload
+                    part_number += 1
+
+            headers = dict()
+            handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers)
         except Exception as e:
             raise e
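
After this patch, whether checkpoint saving takes the asynchronous upload
path is decided purely by the prefix of save_ckpt_folder. Below is a minimal
sketch of a ckpt config that exercises the new prefixes; the field names
mirror those checked in args_sanity_check above, while the bucket URI and
the tmp-folder value are illustrative assumptions, not part of this patch:

    # Hypothetical InternLM ckpt config (values are assumptions for illustration).
    ckpt = dict(
        enable_save_ckpt=True,
        # Any of the "boto3:", "volc:", "oss2:" prefixes keeps async_upload
        # effective; a plain filesystem path warns and falls back to sync save.
        save_ckpt_folder="volc:vc://my-bucket/llm_ckpts/",
        async_upload=True,
        # Local staging directory the async uploader drains from (assumed value).
        async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",
        checkpoint_every=50,
    )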