feat(ckpt): add async upload and ckpt snapshot (#161)

* use fp16 in instruction (#80)

* delete torch_dtype of README's example code (#100)

* feat(ckpt): support async ckpt upload and ckpt snapshot

---------

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
Guoteng 2023-08-08 13:08:36 +08:00 committed by GitHub
parent ff0fa7659f
commit 29d27a6227
5 changed files with 453 additions and 85 deletions


@@ -7,22 +7,29 @@ MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
# oss: 'boto3:s3://model_weights/XXX'
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload (only works for boto3 ckpt).
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
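# Illustrative scheduling sketch (commented out; not part of the config):
# with CHECKPOINT_EVERY = 50 and oss_snapshot_freq = 25, snapshot ckpts
# land at steps 25, 75, 125, ... and normal ckpts at steps 50, 100, ...
# (a step divisible by both counts as a normal ckpt):
# for step in range(1, 151):
#     if step % CHECKPOINT_EVERY == 0:
#         print(step, "normal ckpt")
#     elif step % (CHECKPOINT_EVERY // 2) == 0:
#         print(step, "snapshot ckpt")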
TRAIN_FOLDER = "/path/to/dataset"


@@ -11,6 +11,7 @@ import torch
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager
logger = get_logger(__file__)
@@ -122,20 +123,44 @@ def args_sanity_check():
if "load_model_only_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_model_only_folder", None)
if "async_upload" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload", False)
else:
if gpc.config.ckpt.async_upload:
assert "save_ckpt_folder" in gpc.config.ckpt
if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning(
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
)
gpc.config.ckpt.async_upload = False
else:
if "async_upload_tmp_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "snapshot_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
assert gpc.config.ckpt.oss_snapshot_freq > 0
assert not (
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
if "enable_save_ckpt" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("enable_save_ckpt", False)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
# initialize the storage manager
init_storage_manager(gpc.config.ckpt)
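# Example (hypothetical values) of a minimal ckpt config accepted by the
# checks above; omitted async fields get the defaults shown earlier
# (async_upload_tmp_folder, snapshot_ckpt_folder, oss_snapshot_freq):
# ckpt = dict(
#     enable_save_ckpt=True,
#     save_ckpt_folder="boto3:s3://model_weights.10.0.0.1/internlm",
#     checkpoint_every=50,
#     load_optimizer=True,
#     async_upload=True,  # allowed because save_ckpt_folder is a boto3 path
# )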
# tensorboard writer config
if "enable_tb" not in gpc.config:
gpc.config._add_item("enable_tb", True)


@@ -4,6 +4,7 @@
import copy
import os
import time
from enum import Enum
from typing import Dict
import torch
@@ -15,10 +16,22 @@ from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import (
get_fns,
get_storage_manager,
llm_load,
llm_save,
)
logger = get_logger(__file__)
quit_signal_handler = None
class CheckpointType(Enum):
NORMAL_CHECKPOINT = 1
SNAPSHOT_CHECKPOINT = 2
def get_model_topology(model):
"""
@@ -289,3 +302,77 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
if gpc.is_rank_for_log():
logger.info(f"reload load_scheduler:{lr_scheduler}")
class CheckpointSaveManager:
"""StorageManagerContext"""
def __init__(
self,
ckpt_config,
model,
optimizer,
lr_scheduler,
model_config,
) -> None:
"""
CheckpointSaveManager decides when to store a ckpt. In asynchronous
upload mode, you must call wait_async_upload_finish at the end of the
program to wait for the asynchronous ckpt upload to complete.
Args:
ckpt_config (dict): model checkpoint config.
model (nn.module): model obj
optimizer (object): optimizer obj.
lr_scheduler (object): lr_scheduler obj.
model_config (dict): model config.
"""
self.enable_save_ckpt = ckpt_config.enable_save_ckpt
self.checkpoint_every = ckpt_config.checkpoint_every
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
self.storage_manager = get_storage_manager()
self.snapshot_counter = 0
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.model_config = model_config
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0:
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
if save_ckpts is False:
if quit_signal_handler is not None:
save_ckpts, save_type = quit_signal_handler(train_state)
if save_ckpts:
# Wait for the previous round of asynchronous upload storage to complete.
self.storage_manager.wait()
if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
# Alternate between two snapshot slots; only the two most recent snapshots are kept.
self.snapshot_counter = (self.snapshot_counter + 1) % 2
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
else:
save_ckpt_folder = self.save_ckpt_folder
save_checkpoint(
folder=save_ckpt_folder,
model=self.model,
optimizer=self.optimizer,
scheduler=self.lr_scheduler,
train_state=train_state,
model_config=self.model_config,
)
def wait_async_upload_finish(self):
"""wait for all checkpoint uploads to be completed"""
self.storage_manager.wait()
torch.distributed.barrier()
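# Hypothetical helper (not part of this commit): since only two snapshot
# slots alternate, resuming from a snapshot means picking the newer slot.
# This sketch assumes local-filesystem paths; for boto3 folders the slots
# would have to be inspected via get_fns() instead of os.path.getmtime.
# def latest_snapshot_slot(snapshot_ckpt_folder: str) -> str:
#     slots = [os.path.join(snapshot_ckpt_folder, s) for s in ("0", "1")]
#     slots = [s for s in slots if os.path.isdir(s)]
#     if not slots:
#         raise FileNotFoundError(f"no snapshot found under {snapshot_ckpt_folder}")
#     return max(slots, key=os.path.getmtime)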


@@ -1,18 +1,26 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import asyncio
import concurrent.futures
import hashlib
import io
import os
import pickle
import re
import socket
import stat
from asyncio import InvalidStateError
from asyncio.tasks import ALL_COMPLETED
from datetime import datetime
from typing import Any, Callable, Dict, List, Union
import boto3
import botocore
import torch
import torch.distributed as dist
from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
@@ -41,10 +49,6 @@ def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
class StorageClient:
"""
StorageClient as a client for s3 storage access.
@@ -54,7 +58,7 @@ class StorageClient:
self.handler = handler
@staticmethod
def load(client, load_path: str, *args, **kwargs):
raise NotImplementedError
@staticmethod
@@ -71,25 +75,51 @@ class StorageClient:
class Boto3MetaInfo:
"""Boto3 meta info for save/load etc."""
def __init__(
self,
is_async,
handler: StorageClient,
bucket_name: str,
endpoint: str,
file_path: str,
async_upload_fn: Callable,
local_nvme_path=None,
) -> None:
self.is_async = is_async
self.client = handler
self.bucket_name = bucket_name
self.endpoint = endpoint
self.file_path = file_path
self.async_upload_fn = async_upload_fn
self.local_nvme_path = local_nvme_path
def __str__(self) -> str:
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"
class LocalMetaInfo:
"""Local meta info for save/load etc."""
def __init__(self, handler: StorageClient, dest_path: str) -> None:
self.is_async = False
self.client = handler
self.dest_path = dest_path
self.async_upload_fn = None
def unpack_meta(meta):
args = []
is_async = meta.is_async
for k, v in meta.__dict__.items():
if k == "endpoint":
if k in ("endpoint", "async_upload_fn", "is_async"):
continue
if not is_async and k in ("local_nvme_path",):
continue
args.append(v)
return args
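# Ordering note (illustrative, placeholder values): unpack_meta relies on
# attribute insertion order, so an async Boto3MetaInfo unpacks to exactly
# the positional signature (handler, bucket_name, fp, local_nvme_path)
# shared by Boto3Client.load/sync_upload_fileobj/get_fns:
# meta = Boto3MetaInfo(True, None, "model_weights", "http://10.0.0.1:80",
#                      "internlm/ckpt.pt", None, "/dev/shm/a.tmpfile")
# assert unpack_meta(meta) == [None, "model_weights", "internlm/ckpt.pt", "/dev/shm/a.tmpfile"]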
@@ -101,21 +131,6 @@ def compute_file_md5_by_chunk(file_name: str):
return hash_md5.hexdigest()
class Boto3Client(StorageClient):
"""
Boto3Client
@@ -169,7 +184,9 @@
)
@staticmethod
def sync_upload_fileobj(
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
): # pylint: disable=W0613
assert saved_obj is not None, "saved_obj is None!"
try:
with io.BytesIO() as f:
@@ -182,7 +199,14 @@
) from exc
@staticmethod
def load(
handler,
bucket_name: str,
fp: str,
local_nvme_path: str, # pylint: disable=W0613
*args,
**kwargs,
) -> Dict:
"""
Args:
fp (str): Path to load, e.g. s3://opennlplab/model_weights/xxx/ddd.pt
@@ -191,7 +215,7 @@
with io.BytesIO() as f:
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
f.seek(0)
states = torch.load(f, *args, **kwargs)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@@ -199,15 +223,11 @@
return states
@staticmethod
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
@staticmethod
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
"""
Ref: https://stackoverflow.com/questions/54314563/
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
@@ -222,6 +242,22 @@
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
return folder_name_list
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
try:
with open(local_nvme_path, "rb") as f:
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
) from exc
except Exception as e:
raise e
@staticmethod
def delete_obj(handler, fp: str):
raise NotImplementedError("boto3 not support delete_obj")
class LocalClient(StorageClient):
"""
@@ -241,11 +277,11 @@
torch.save(saved_obj, fp, *args, **kwargs)
@staticmethod
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
assert isinstance(handler, LocalClient)
assert os.path.exists(fp), f"{fp} is not found!"
with open(fp, "rb") as f:
states = torch.load(f, *args, **kwargs)
return states
@staticmethod
@@ -267,9 +303,77 @@
os.remove(fp)
def get_tmp_file_name(tmp_local_folder: str, fp: str):
"""
It should be noted that all our temporary files will be stored in the same folder,
so the file name passed upstream must be unique.
"""
base_path = os.path.join(tmp_local_folder, fp.split("/")[-1])
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
pid = os.getpid()
return "-".join([base_path, current_time, str(pid)]) + ".tmpfile"
def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
parts = fp[len("s3://"):].split(os.path.sep)
match = boto3_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
return Boto3MetaInfo(
is_async=is_async,
handler=None,
bucket_name=bucket_name,
endpoint=endpoint,
file_path=os.path.sep.join(parts[1:]),
async_upload_fn=Boto3Client.async_upload_fileobj,
local_nvme_path=tmp_step_file,
)
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
def get_mount_point_free_size(path: str):
"""
Returns the remaining free space of the temporary storage mount point, in GB.
Args:
path (str): temporary storage folder path.
Raises:
FileNotFoundError: If the temporary storage folder does not exist,
an error will be reported
"""
if os.path.exists(path):
st = os.statvfs(path)
# f_bavail: Number of free blocks for unprivileged users.
# f_bsize: Filesystem block size.
# return unit is GB.
return st.f_bavail * st.f_bsize / (1024**3)
raise FileNotFoundError(f"Temporary storage folder '{path}' does not exist!")
def check_tmp_folder_accessibility(tmp_local_folder: str):
"""
Check access permissions for temporary storage.
"""
ret = True
if os.path.exists(tmp_local_folder):
ret &= os.access(tmp_local_folder, os.W_OK)
ret &= os.access(tmp_local_folder, os.R_OK)
if ret is False:
error_str = f"{socket.gethostname()} does not have read and write permissions on {tmp_local_folder}"
raise RuntimeError(error_str)
class StorageManager(metaclass=SingletonMeta):
"""
Storage Manager for saving or loading checkpoint.
TODO: add a thread to poll the asynchronous storage state.
"""
BACKEND_TYPE = {"boto3", "local"}
@@ -279,8 +383,39 @@
}
CLI_DICT = {}
def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
self._exception_list = []
self._to_be_del_files = []
self._async_stack = []
self.upload_count = 0
self.tmp_local_folder = tmp_local_folder
self.async_mode = async_mode
self.has_warning = False
if enable_save and self.async_mode:
self._async_loop = asyncio.new_event_loop()
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
# Try to create tmp folder
try:
os.makedirs(self.tmp_local_folder, exist_ok=True)
os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
except FileExistsError:
pass
# In case it is a directory created by other users, we check the permissions again.
check_tmp_folder_accessibility(self.tmp_local_folder)
# Try to clean up leftover temporary files in the tmp folder.
self.try_delete_tmpfile(self.tmp_local_folder)
# Available storage space check.
free_size = get_mount_point_free_size(self.tmp_local_folder)
if free_size < 0.1:
logger.error(f'tmp_local_folder only has "{free_size}" GB free space, less than the required 0.1 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path: str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
"""
@@ -301,7 +436,7 @@
meta_info = get_local_meta(path)
backend_key = backend
elif backend == "boto3":
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (meta_info.endpoint,)
if (
@@ -310,10 +445,12 @@
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning:
logger.warning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."
)
self.has_warning = True
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
@@ -333,19 +470,137 @@ the proxy may make boto3 unavailable or affect performance."
meta = self._get_client(path=folder)
return meta.client.get_fns(*unpack_meta(meta))
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
meta = self._get_client(path=save_path)
if async_upload is None:
async_upload = self.async_mode
if async_upload:
assert (
self.tmp_local_folder
), "StorageManager is not setted tmp_local_folder, so async save cannot be performed."
tmp_step_file = meta.local_nvme_path
self._to_be_del_files.append(tmp_step_file)
with open(tmp_step_file, "wb") as f:
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
else:
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
self.upload_count += 1
def load(self, load_path: str, *args, **kwargs) -> Any:
self.wait()
meta = self._get_client(path=load_path)
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
def delete_obj(self, fp: str):
meta = self._get_client(path=fp)
meta.client.delete_obj(*unpack_meta(meta))
def _del_tmp_folder(self):
for fp in self._to_be_del_files:
try:
os.remove(fp)
except FileNotFoundError:
pass
except OSError as e:
logger.error(f'delete file: {fp}, failed for reason:"{e}"')
def try_delete_tmpfile(self, tmp_dir: str):
"""Delete temporary files in tmp_dir."""
for filename in os.listdir(tmp_dir):
if filename.endswith(".tmpfile"):
file_path = os.path.join(tmp_dir, filename)
try:
os.remove(file_path)
logger.info(f"Delete tmpfile: {file_path}")
except OSError:
# Ignore deletion errors
pass
async def _sync_tasks(self) -> None:
if not self._async_stack:
return
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
for task in self._async_stack:
try:
task.exception()
except InvalidStateError:
continue
except Exception as e:
file_id = len(self._exception_list)
self._exception_list.append((e, file_id))
logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}")
self._async_stack.clear()
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
"""
Overview:
Execute a task in the background, then append the future instance to _async_stack.
Arguments:
- fn (:obj:`Callable`): Synchronous function to run in the thread pool.
"""
if not self._async_loop:
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
self._async_stack.append(t)
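# Self-contained sketch of the same mechanism, independent of
# StorageManager: an event loop schedules blocking functions onto a
# thread pool and keeps the futures for a later blocking wait():
# import asyncio, concurrent.futures, time
# loop = asyncio.new_event_loop()
# pool = concurrent.futures.ThreadPoolExecutor(max_workers=8)
# stack = []
# def upload(path):  # stand-in for Boto3Client.async_upload_fileobj
#     time.sleep(0.1)  # simulate network I/O
#     return path
# for i in range(4):  # fire-and-forget, as in save() with async_upload=True
#     stack.append(loop.run_in_executor(pool, upload, f"ckpt_{i}.pt"))
# loop.run_until_complete(asyncio.wait(stack, return_when=asyncio.ALL_COMPLETED))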
def wait(self) -> None:
"""Wait for async operations to complete."""
if not self.async_mode:
return
if self._async_loop:
self._async_loop.run_until_complete(self._sync_tasks())
if self._exception_list:
for error_msg, file_id in self._exception_list:
logger.error(
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
f"failed on step {self.upload_count}: {error_msg}"
)
# TODO: Re-upload in sync mode
raise RuntimeError(
f"Failed to upload {self._to_be_del_files[file_id]} " f"on step {self.upload_count}: {error_msg}"
)
self._del_tmp_folder()
self._exception_list.clear()
self._to_be_del_files.clear()
if gpc.is_rank_for_log():
logger.info("all async uploads succeeded!")
self.upload_count += 1
storage_manager: StorageManager = None
def init_storage_manager(ckpt_config):
global storage_manager
storage_manager = StorageManager(
ckpt_config.enable_save_ckpt,
tmp_local_folder=ckpt_config.async_upload_tmp_folder,
async_mode=ckpt_config.async_upload,
)
def get_storage_manager():
assert storage_manager is not None, "storage_manager has not been initialized!"
return storage_manager
def wait_async_upload_finish():
dist.barrier()
storage_manager.wait()
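# Intended call sequence (sketch; config object and save path are placeholders):
# init_storage_manager(gpc.config.ckpt)  # once, during args_sanity_check()
# llm_save("boto3:s3://bucket.1.2.3.4/internlm/model.pt", saved_obj=state_dict)
# get_storage_manager().wait()  # block until pending async uploads land
# # ...or at the very end of training, with a barrier across ranks:
# wait_async_upload_finish()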


@@ -46,12 +46,12 @@ from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_paral
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import (
CheckpointSaveManager,
load_context,
load_model_checkpoint,
load_optimizer_checkpoint,
load_sampler,
load_scheduler,
)
from internlm.utils.parallel import (
get_parallel_log_file_name,
@@ -432,11 +432,6 @@ def main(args):
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
# ckpt setting
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
@@ -527,6 +522,14 @@
if load_optimizer:
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
ckpt_save_manager = CheckpointSaveManager(
ckpt_config=gpc.config.ckpt,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
model_config=gpc.config.model,
)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),
@@ -645,19 +648,10 @@
)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# save batch sampler that tracks the true consumed samples
ckpt_save_manager.try_save_checkpoint(train_state)
# wait for all checkpoint uploads to be completed
ckpt_save_manager.wait_async_upload_finish()
if __name__ == "__main__":