From 51dd3da03ef2f3799ad98d13697ae09ae9b28267 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 15 Dec 2023 11:42:04 +0800 Subject: [PATCH] optimize model ckpt and reduce checkpointing overhead --- internlm/utils/model_checkpoint.py | 4 +++- internlm/utils/storage_manager.py | 25 ++++++++++++++++++++----- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 234944c..f912f52 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -859,8 +859,10 @@ class CheckpointManager: self.async_upload = get_config_value(ckpt_config, "async_upload", False) + if self.save_ckpt_folder.startswith("volc:") or self.save_ckpt_folder.startswith("oss2:"): + use_processpool = True # initialization storage manager - init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload) + init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload, use_processpool) self.feishu_address = feishu_address self.storage_manager = get_storage_manager() diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 53a4e37..8afcbd2 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -1,6 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import multiprocessing + +import dill + +dill.Pickler.dumps, dill.Pickler.loads = dill.dumps, dill.loads +multiprocessing.reduction.ForkingPickler = dill.Pickler +multiprocessing.reduction.dump = dill.dump + import asyncio import concurrent.futures import hashlib @@ -976,7 +984,9 @@ class StorageManager(metaclass=SingletonMeta): } CLI_DICT = {} - def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None: + def __init__( + self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, use_processpool=False, n_async_workers=8 + ) -> None: self._exception_list = [] self._to_be_del_files = [] self._async_stack = [] @@ -985,14 +995,18 @@ class StorageManager(metaclass=SingletonMeta): self.async_mode = async_mode self.has_warning = False self._async_loop = None - self._thread_pool = None + self._executor_pool = None self.latest_save_folder = None self.latest_save_step = 0 self.async_task_peeding = False if enable_save and self.async_mode: self._async_loop = asyncio.new_event_loop() - self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers) + + if use_processpool: + self._executor_pool = concurrent.futures.ProcessPoolExecutor(max_workers=n_async_workers) + else: + self._executor_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers) check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder)) @@ -1196,7 +1210,7 @@ class StorageManager(metaclass=SingletonMeta): """ if not self._async_loop: raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode") - t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs) + t = self._async_loop.run_in_executor(self._executor_pool, fn, *args, **kwargs) self._async_stack.append(t) def wait(self) -> bool: @@ -1242,12 +1256,13 @@ class StorageManager(metaclass=SingletonMeta): storage_manager: StorageManager = None -def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload): +def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload, use_processpool=False): global storage_manager storage_manager = StorageManager( enable_save_ckpt, tmp_local_folder=async_upload_tmp_folder, async_mode=async_upload, + use_processpool=use_processpool, )