feat(ckpt): support auto resume in Volc and Ali (#529)

* multipart upload

* upload

* storage

* storage

* storage

* storage

* change ak sk name

* change ak sk name

* change ak sk name

* change ak sk name

* storage

* storage

* auto resume

* auto resume

* auto resume

* bug
pull/538/head
jiaxingli 2023-12-12 13:27:24 +08:00 committed by GitHub
parent cc5b15349d
commit d904730be7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 8 deletions

View File

@@ -1016,7 +1016,8 @@ now step_count is {train_state.step_count}",
torch.distributed.barrier()
def query_latest_snapshot_step_boto3(self):
"""query_latest_snapshot_step_boto3
"""Query the latest snapshot step from the storage backend.
Currently, we only support the following storage backends: boto3, oss2 and volc.
Returns:
Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
"""
@@ -1074,6 +1075,7 @@ now step_count is {train_state.step_count}",
return load_path, max(snap_step, max_normal_step)
def query_latest_snapshot_step_local(self):
"""Query the latest snapshot step from the local file system."""
max_step, max_step_path = 0, None
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
@@ -1090,18 +1092,22 @@ now step_count is {train_state.step_count}",
return max_step_path, max_step
def query_lastest_ckpt(self):
"""Query the latest ckpt via the storage backend."""
latest_ckpt, step = None, -1
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder:
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
if backend == "boto3":
if backend in ["boto3", "oss2", "volc"]:
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
if latest_ckpt and not latest_ckpt.startswith("boto3:"):
latest_ckpt = ":".join(["boto3", latest_ckpt])
elif backend == "local":
latest_ckpt, step = self.query_latest_snapshot_step_local()
if latest_ckpt and not latest_ckpt.startswith("local:"):
latest_ckpt = ":".join(["local", latest_ckpt])
else:
raise NotImplementedError(
f"Unsupported backend: {backend}, " "Currently only support `boto3`, `oss2`, `volc` and `local`"
)
if latest_ckpt and not latest_ckpt.startswith(backend + ":"):
latest_ckpt = ":".join([backend, latest_ckpt])
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")

View File

@@ -739,10 +739,9 @@ class AliClient(StorageClient):
if AliClient.is_fp_exists(handler, fp):
folder_name_list = []
for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp):
folder_name_list.append(obj.key.split("/")[-1])
folder_name_list.append(obj.key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
else:
if is_rank_for_log():
logger.warning(f"'{fp}' not found!")