mirror of https://github.com/InternLM/InternLM
optimize model ckpt and reduce checkpointing overhead
parent
bbb5651582
commit
51dd3da03e
|
@ -859,8 +859,10 @@ class CheckpointManager:
|
|||
|
||||
self.async_upload = get_config_value(ckpt_config, "async_upload", False)
|
||||
|
||||
if self.save_ckpt_folder.startswith("volc:") or self.save_ckpt_folder.startswith("oss2:"):
|
||||
use_processpool = True
|
||||
# initialization storage manager
|
||||
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)
|
||||
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload, use_processpool)
|
||||
|
||||
self.feishu_address = feishu_address
|
||||
self.storage_manager = get_storage_manager()
|
||||
|
|
|
@ -1,6 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import multiprocessing
|
||||
|
||||
import dill
|
||||
|
||||
dill.Pickler.dumps, dill.Pickler.loads = dill.dumps, dill.loads
|
||||
multiprocessing.reduction.ForkingPickler = dill.Pickler
|
||||
multiprocessing.reduction.dump = dill.dump
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
|
@ -976,7 +984,9 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
}
|
||||
CLI_DICT = {}
|
||||
|
||||
def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
|
||||
def __init__(
|
||||
self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, use_processpool=False, n_async_workers=8
|
||||
) -> None:
|
||||
self._exception_list = []
|
||||
self._to_be_del_files = []
|
||||
self._async_stack = []
|
||||
|
@ -985,14 +995,18 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
self.async_mode = async_mode
|
||||
self.has_warning = False
|
||||
self._async_loop = None
|
||||
self._thread_pool = None
|
||||
self._executor_pool = None
|
||||
self.latest_save_folder = None
|
||||
self.latest_save_step = 0
|
||||
self.async_task_peeding = False
|
||||
|
||||
if enable_save and self.async_mode:
|
||||
self._async_loop = asyncio.new_event_loop()
|
||||
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
|
||||
|
||||
if use_processpool:
|
||||
self._executor_pool = concurrent.futures.ProcessPoolExecutor(max_workers=n_async_workers)
|
||||
else:
|
||||
self._executor_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
|
||||
|
||||
check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
|
||||
|
||||
|
@ -1196,7 +1210,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
"""
|
||||
if not self._async_loop:
|
||||
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
|
||||
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
|
||||
t = self._async_loop.run_in_executor(self._executor_pool, fn, *args, **kwargs)
|
||||
self._async_stack.append(t)
|
||||
|
||||
def wait(self) -> bool:
|
||||
|
@ -1242,12 +1256,13 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
storage_manager: StorageManager = None
|
||||
|
||||
|
||||
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
|
||||
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload, use_processpool=False):
|
||||
global storage_manager
|
||||
storage_manager = StorageManager(
|
||||
enable_save_ckpt,
|
||||
tmp_local_folder=async_upload_tmp_folder,
|
||||
async_mode=async_upload,
|
||||
use_processpool=use_processpool,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue