optimize model ckpt and reduce checkpointing overhead

pull/543/head
zigzagcai 2023-12-15 11:42:04 +08:00
parent bbb5651582
commit 51dd3da03e
2 changed files with 23 additions and 6 deletions

View File

@ -859,8 +859,10 @@ class CheckpointManager:
self.async_upload = get_config_value(ckpt_config, "async_upload", False)
if self.save_ckpt_folder.startswith("volc:") or self.save_ckpt_folder.startswith("oss2:"):
use_processpool = True
# initialize the storage manager
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload, use_processpool)
self.feishu_address = feishu_address
self.storage_manager = get_storage_manager()

View File

@ -1,6 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import multiprocessing
import dill
dill.Pickler.dumps, dill.Pickler.loads = dill.dumps, dill.loads
multiprocessing.reduction.ForkingPickler = dill.Pickler
multiprocessing.reduction.dump = dill.dump
import asyncio
import concurrent.futures
import hashlib
@ -976,7 +984,9 @@ class StorageManager(metaclass=SingletonMeta):
}
CLI_DICT = {}
def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
def __init__(
self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, use_processpool=False, n_async_workers=8
) -> None:
self._exception_list = []
self._to_be_del_files = []
self._async_stack = []
@ -985,14 +995,18 @@ class StorageManager(metaclass=SingletonMeta):
self.async_mode = async_mode
self.has_warning = False
self._async_loop = None
self._thread_pool = None
self._executor_pool = None
self.latest_save_folder = None
self.latest_save_step = 0
self.async_task_peeding = False
if enable_save and self.async_mode:
self._async_loop = asyncio.new_event_loop()
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
if use_processpool:
self._executor_pool = concurrent.futures.ProcessPoolExecutor(max_workers=n_async_workers)
else:
self._executor_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
@ -1196,7 +1210,7 @@ class StorageManager(metaclass=SingletonMeta):
"""
if not self._async_loop:
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
t = self._async_loop.run_in_executor(self._executor_pool, fn, *args, **kwargs)
self._async_stack.append(t)
def wait(self) -> bool:
@ -1242,12 +1256,13 @@ class StorageManager(metaclass=SingletonMeta):
storage_manager: StorageManager = None
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload, use_processpool=False):
global storage_manager
storage_manager = StorageManager(
enable_save_ckpt,
tmp_local_folder=async_upload_tmp_folder,
async_mode=async_upload,
use_processpool=use_processpool,
)