mirror of https://github.com/InternLM/InternLM

feat(ckpt): add auto ckpt load and singal quit (#216)
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
branch: pull/218/head^2
parent: 53648dc0e9
commit: 29779c75f0
@@ -108,67 +108,96 @@ def args_sanity_check():
         logger.info(f"valid_every: {data.valid_every}")

     # processing the checkpoint config
-    if "enable_save_ckpt" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("enable_save_ckpt", False)
-
-    if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
-        gpc.config.ckpt._add_item("checkpoint_every", float("inf"))
-
-    if "load_optimizer" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("load_optimizer", True)
-
-    if "save_ckpt_folder" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("save_ckpt_folder", None)
-
-    if "load_ckpt_folder" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("load_ckpt_folder", None)
-
-    if "load_model_only_folder" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("load_model_only_folder", None)
-
-    if "async_upload" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("async_upload", False)
-
-    if "async_upload_tmp_folder" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
-
-    if gpc.config.ckpt.async_upload:
-        assert "save_ckpt_folder" in gpc.config.ckpt
-        if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
-            if gpc.is_rank_for_log():
-                logger.warning("Storing ckpt on file system does not support asynchronous storage, will use sync save!")
-            gpc.config.ckpt.async_upload = False
-
-    if "snapshot_ckpt_folder" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
-
-    if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
-        gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
-        assert gpc.config.ckpt.oss_snapshot_freq > 0
-
-    assert not (
-        gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
-    ), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
+    ckpt = gpc.config.ckpt
+    if "enable_save_ckpt" not in ckpt:
+        ckpt._add_item("enable_save_ckpt", False)
+
+    # Saving checkpoint args.
+    if ckpt.enable_save_ckpt:
+        assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
+        assert ckpt.checkpoint_every > 0
+        assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
+
+        if "async_upload" not in ckpt:
+            ckpt._add_item("async_upload", False)  # async defalut is False.
+        else:
+            if ckpt.async_upload:
+                assert "save_ckpt_folder" in ckpt
+                if "boto3:" not in ckpt.save_ckpt_folder:
+                    if gpc.is_rank_for_log():
+                        logger.warning(
+                            "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
+                        )
+                    ckpt.async_upload = False
+                else:
+                    if "async_upload_tmp_folder" not in ckpt:
+                        ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
+
+        if not ckpt.async_upload:
+            ckpt._add_item("async_upload_tmp_folder", None)
+
+        if "snapshot_ckpt_folder" not in ckpt:
+            ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
+
+        if "oss_snapshot_freq" not in ckpt:
+            ckpt._add_item("oss_snapshot_freq", float("inf"))  # if oss_snapshot_freq not given, we disable.
+    else:
+        ckpt._add_item("checkpoint_every", float("inf"))
+        ckpt._add_item("oss_snapshot_freq", float("inf"))
+        ckpt._add_item("save_ckpt_folder", None)
+        ckpt._add_item("async_upload", False)
+        ckpt._add_item("async_upload_tmp_folder", None)
+        ckpt._add_item("snapshot_ckpt_folder", None)
+        ckpt._add_item("snapshot_ckpt_folder", None)
+
+    # Loading checkpoint args.
+    if "load_model_only_folder" not in ckpt:
+        ckpt._add_item("load_model_only_folder", None)
+
+    if "load_ckpt_folder" not in ckpt:
+        ckpt._add_item("load_ckpt_folder", None)
+
+    if "load_optimizer" not in ckpt:
+        ckpt._add_item("load_optimizer", True)
+
+    if "stop_file_path" not in ckpt:
+        ckpt._add_item("stop_file_path", None)
+
+    if "load_given_ckpt" not in ckpt:
+        # If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity
+        # to auto-load latest checkpoint.
+        ckpt._add_item("load_given_ckpt", False)
+
+    if ckpt.load_given_ckpt:
+        # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
+        if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
+            logger.warning(
+                "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
+and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
+            )
+            ckpt.load_model_only_folder = None

     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
-        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
-        logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
-        logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
-        logger.info(f"async_upload: {gpc.config.ckpt.async_upload}")
-        if gpc.config.ckpt.async_upload:
-            logger.info(f"async_upload_tmp_folder: {gpc.config.ckpt.async_upload_tmp_folder}")
+        logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
+        logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
+        logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
+        logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")

     # initialization storage manager
-    init_storage_manager(gpc.config.ckpt)
+    init_storage_manager(ckpt)

     # tensorboard writer config
     if "enable_tb" not in gpc.config:
         gpc.config._add_item("enable_tb", True)
     if "tensorboard_folder" not in gpc.config:
-        gpc.config._add_item("tensorboard_folder", None)
+        gpc.config._add_item(
+            "tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
+        )
     if "resume_tb_folder" not in gpc.config:
-        gpc.config._add_item("resume_tb_folder", None)
+        gpc.config._add_item(
+            "resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
+        )

     # cudnn
     torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
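For reference, a ckpt config block that satisfies the new checks in args_sanity_check() could look roughly like the following; the values and paths are illustrative examples, not defaults taken from this commit:

ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:/path/to/ckpts",  # "boto3:..." style paths are what enables async upload
    checkpoint_every=50,
    async_upload=False,                        # forced to False for non-boto3 folders anyway
    oss_snapshot_freq=25,                      # omitted => float("inf"), i.e. snapshots disabled
    load_given_ckpt=False,                     # False lets internlm auto-load the latest ckpt from save_ckpt_folder
    load_model_only_folder=None,
    load_ckpt_folder=None,
    load_optimizer=True,
    stop_file_path="/path/to/shared/stop_file",  # hypothetical shared path; enables the quit-signal file
)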
@@ -2,7 +2,9 @@
 # -*- encoding: utf-8 -*-

 import copy
+import fcntl
 import os
+import socket
 import time
 from enum import Enum
 from typing import Dict
@@ -12,6 +14,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer import TrainState
+from internlm.monitor import send_alert_message
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
@@ -25,8 +28,6 @@ from internlm.utils.storage_manager import (

 logger = get_logger(__file__)

-quit_signal_handler = None
-

 class CheckpointType(Enum):
     NORMAL_CHECKPOINT = 1
@@ -167,44 +168,6 @@ def save_optimizer_checkpoint(optim, state_path):
         llm_save(os.path.join(state_path, fp), states)


-def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
-    """
-    Save checkpoint to the given folder path.
-    """
-
-    start = time.time()
-    torch.distributed.barrier()
-    folder = os.path.join(folder, str(train_state.step_count))
-    logger.info(
-        f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..."
-    )
-
-    timer("save-model").start()
-    save_model_checkpoint(folder=folder, model=model)
-    timer("save-model").stop()
-
-    timer("save-optimizer").start()
-    save_optimizer_checkpoint(optim=optimizer, state_path=folder)
-    timer("save-optimizer").stop()
-
-    if gpc.is_rank_for_log():
-        scheduler_states = scheduler.state_dict()
-        llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
-
-        sampler_state = train_state.batch_sampler.state_dict()
-        llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
-        llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
-
-        if model_config is not None:
-            llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
-
-    torch.distributed.barrier()
-
-    if gpc.is_rank_for_log():
-        timer.log(["save-model", "save-optimizer"], logger=logger)
-        logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
-
-
 def load_optimizer_checkpoint(folder, optim):
     """Load the optimizer state from the local file system or remote
     object storage Service (OSS).
@@ -304,19 +267,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
         logger.info(f"reload load_scheduler:{lr_scheduler}")


-class CheckpointSaveManager:
+class CheckpointManager:
     """StorageManagerContext"""

-    def __init__(
-        self,
-        ckpt_config,
-        model,
-        optimizer,
-        lr_scheduler,
-        model_config,
-    ) -> None:
+    def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
         """
-        CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous
+        CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
         upload mode, you must call wait_async_upload_finish at the end of the program to wait
         for the asynchronous ckpt upload to complete.

@@ -332,26 +288,95 @@ class CheckpointSaveManager:
         self.save_ckpt_folder = ckpt_config.save_ckpt_folder
         self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
         self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
+        self.stop_file_path = ckpt_config.stop_file_path
+        self.load_model_only_folder = ckpt_config.load_model_only_folder
+        self.feishu_address = feishu_address
         self.storage_manager = get_storage_manager()
         self.snapshot_counter = 0
+        self.load_optimizer = gpc.config.ckpt.load_optimizer

         self.model = model
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
         self.model_config = model_config

+        if self.stop_file_path and gpc.get_global_rank() == 0:
+            dir_path = os.path.dirname(self.stop_file_path)
+            if dir_path != "" and not os.path.exists(dir_path):
+                os.makedirs(dir_path)
+            with open(self.stop_file_path, "w", encoding="utf-8") as f:
+                f.write("0")
+
+        if ckpt_config.load_given_ckpt is False:
+            # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
+            latest_ckpt_path = self.query_lastest_ckpt()
+            if latest_ckpt_path:
+                self.load_ckpt_folder = latest_ckpt_path
+            else:
+                # At this time, we have to load model init weights and train from step 0.
+                self.load_ckpt_folder = self.load_model_only_folder
+        else:
+            self.load_ckpt_folder = ckpt_config.load_ckpt_folder
+
+        if gpc.is_rank_for_log():
+            logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
+            if self.stop_file_path is None:
+                logger.warning("no set stop_file_path, quit_signal_handler is disable")
+
+    def quit_signal_handler(self, train_state) -> bool:
+        """
+        Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file,
+        all ranks will save ckpt and exit.
+        Negative integer step means save ckpt.
+        Positive integer step means save ckpt and quit.
+
+        Args:
+            train_state (TrainState):
+        Returns:
+            bool: whether to quit.
+        """
+        now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
+
+        if self.stop_file_path is None:
+            return now_break, now_save_ckpt, save_type
+
+        with open(self.stop_file_path, "a+", encoding="utf-8") as f:
+            fcntl.flock(f, fcntl.LOCK_EX)
+            f.seek(0)
+            msg = f.read()
+            fcntl.flock(f, fcntl.LOCK_UN)
+            action_step = int(msg)
+
+        if action_step < 0 and abs(action_step) == train_state.step_count:
+            now_save_ckpt = True
+
+        if action_step > 0 and action_step == train_state.step_count:
+            now_break, now_save_ckpt = True, True
+
+        if action_step != 0 and gpc.is_rank_for_log():
+            msg = "Stop" if action_step > 0 else "Save"
+            action_step = abs(action_step)
+            if train_state.step_count <= action_step:
+                if self.feishu_address:
+                    send_alert_message(
+                        address=self.feishu_address,
+                        message=f"training will {msg} at step_count {action_step}!\
+now step_count is {train_state.step_count}",
+                    )
+
+        return now_break, now_save_ckpt, save_type
+
     def try_save_checkpoint(self, train_state):
         if not self.enable_save_ckpt:
-            return
+            return False

         save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
         if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
             save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
         if train_state.step_count % self.checkpoint_every == 0:
             save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
+        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
         if save_ckpts is False:
-            if quit_signal_handler is not None:
-                save_ckpts, save_type = quit_signal_handler(train_state)
+            save_ckpts = singal_save_ckpts
+            save_type = singal_save_type

         if save_ckpts:
             # Wait for the previous round of asynchronous upload storage to complete.
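As a usage sketch of the stop-file protocol implemented by quit_signal_handler above: the file holds a single integer step; a negative value asks every rank to save a checkpoint when that step is reached, a positive value asks them to save and then exit, and 0 is a no-op. The path below is hypothetical and must be the same ckpt.stop_file_path visible to all ranks.

stop_file_path = "/mnt/shared/internlm_stop_file"  # hypothetical shared path

# Save a ckpt at step 1000 and keep training:
with open(stop_file_path, "w", encoding="utf-8") as f:
    f.write("-1000")

# Save a ckpt at step 1000 and then quit:
with open(stop_file_path, "w", encoding="utf-8") as f:
    f.write("1000")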
@@ -361,9 +386,9 @@ class CheckpointSaveManager:
             self.snapshot_counter = (self.snapshot_counter + 1) % 2
             save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
         else:
-            save_ckpt_folder = self.save_ckpt_folder
+            save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count))

-        save_checkpoint(
+        self.save_checkpoint(
             folder=save_ckpt_folder,
             model=self.model,
             optimizer=self.optimizer,
@@ -372,7 +397,220 @@ class CheckpointSaveManager:
             model_config=self.model_config,
         )

+        return now_break
+
     def wait_async_upload_finish(self):
         """wait for all checkpoint uploads to be completed"""
         self.storage_manager.wait()
         torch.distributed.barrier()
+
+    def query_latest_snapshot_step_boto3(self):
+        """query_latest_snapshot_step_boto3
+        Returns:
+            Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
+        """
+        ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
+        if len(ckpt_list) == 0:
+            return None, None
+
+        max_normal_step = 0
+        ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
+        ckpt_list.sort(reverse=True)
+        for ckpt in ckpt_list:
+            fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
+            for fn in fns_list:
+                if fn.endswith(".step"):
+                    max_normal_step = ckpt
+                    break
+            if max_normal_step != 0:
+                break
+
+        max_normal_step = ckpt_list[0]
+        load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
+
+        snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
+        snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
+        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
+        ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
+        max_step_0, max_step_1 = 0, 0
+        for ckpt in ckpt_list_1:
+            ckpt = ckpt.strip("/")
+            if ckpt.endswith(".step"):
+                max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
+        for ckpt in ckpt_list_2:
+            ckpt = ckpt.strip("/")
+            if ckpt.endswith(".step"):
+                max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
+
+        snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
+        snap_step = max(max_step_0, max_step_1)
+        load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
+        load_step = max(snap_step, max_normal_step)
+        return load_path, load_step
+
+    def query_latest_snapshot_step_local(self):
+        max_step, max_step_path = 0, None
+        for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
+            for fn in files:
+                fn = fn.strip("/")
+                if fn.endswith(".step"):
+                    # We assume that both normal ckpt and snapshot ckpt will store the '.step' file
+                    # as an integrity flag.
+                    step = int(fn.rsplit(".", maxsplit=1)[0])
+                    if max_step < step:
+                        max_step = step
+                        max_step_path = root
+
+        return max_step_path, max_step
+
+    def query_lastest_ckpt(self):
+        latest_checkpoint = None
+        # Training was automatically restarted by the process, forcing the latest snapshot to be read.
+        if self.save_ckpt_folder:
+            if self.save_ckpt_folder.startswith("boto3"):
+                latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
+            elif self.save_ckpt_folder.startswith("local"):
+                latest_checkpoint, step = self.query_latest_snapshot_step_local()
+            else:
+                latest_checkpoint, step = None, 0
+
+            if latest_checkpoint is not None:
+                if gpc.is_rank_for_log():
+                    logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
+                    send_alert_message(
+                        address=self.feishu_address,
+                        message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
+                    )
+            else:
+                if gpc.is_rank_for_log():
+                    send_alert_message(
+                        address=self.feishu_address,
+                        message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
+                    )
+
+        return latest_checkpoint
+
+    def try_load_model(self, current_time=""):
+        model_load_path = None
+
+        if self.load_ckpt_folder and self.load_model_only_folder:
+            raise ValueError(
+                "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
+if you only need to load model weights (for example starting an SFT task for the first time), \
+set load_model_only_folder path, if you need to resume training from ckpt, \
+set load_ckpt_folder or use default value \
+(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
+            )
+
+        if self.load_ckpt_folder:
+            if gpc.is_rank_for_log():
+                logger.info(
+                    f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
+                    f"{socket.gethostname()}==========="
+                )
+            model_load_path = self.load_ckpt_folder
+        elif self.load_model_only_folder:
+            if gpc.is_rank_for_log():
+                logger.info(
+                    f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
+                    f"{socket.gethostname()}==========="
+                )
+            model_load_path = self.load_model_only_folder
+        else:
+            if gpc.is_rank_for_log():
+                logger.info(
+                    f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
+                    f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
+                    f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
+                )
+
+        # Loading model weights must be done before zero is initialized.
+        if model_load_path is not None:
+            load_model_checkpoint(folder=model_load_path, model=self.model)
+
+    def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
+        """Attempt to restore the training state of the last ckpt.
+
+        Args:
+            lr_scheduler (_LRScheduler): lr_scheduler object.
+            optimizer (Optimizer): optimizer object.
+            lr (float): learning rate.
+            train_state (dict): traing states.
+            train_dl (DataLoader): traning dataloader object
+        """
+        if self.load_ckpt_folder is not None:
+            # load optimzier states.
+            if self.load_optimizer:
+                load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
+            # load lr scheduler states.
+            load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
+            # load training states.
+            load_context(self.load_ckpt_folder, train_dl, train_state)
+            # load dataloader sampler states.
+            if hasattr(train_state, "batch_sampler") and not isinstance(
+                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
+            ):
+                load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
+            if hasattr(train_state, "data_state_dict"):
+                train_dl.dataset.load_state_dict(
+                    llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
+                )
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+
+    def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
+        """
+        Save checkpoint to the given folder path.
+        """
+
+        start = time.time()
+        self.set_save_folder(folder, train_state.step_count)
+        torch.distributed.barrier()
+        if gpc.is_rank_for_log():
+            logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
+
+        timer("save-model").start()
+        save_model_checkpoint(folder=folder, model=model)
+        timer("save-model").stop()
+
+        timer("save-optimizer").start()
+        save_optimizer_checkpoint(optim=optimizer, state_path=folder)
+        timer("save-optimizer").stop()
+
+        if (
+            hasattr(train_state, "data_state_dict")
+            and gpc.get_local_rank(ParallelMode.TENSOR) == 0
+            and gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        ):
+            llm_save(
+                os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"),
+                saved_obj=train_state.data_state_dict,
+            )
+
+        if gpc.is_rank_for_log():
+            scheduler_states = scheduler.state_dict()
+            llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
+            if hasattr(train_state, "batch_sampler") and not isinstance(
+                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
+            ):
+                sampler_state = train_state.batch_sampler.state_dict()
+                llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
+            llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
+
+            if model_config is not None:
+                llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
+
+        torch.distributed.barrier()
+
+        if gpc.is_rank_for_log():
+            timer.log(["save-model", "save-optimizer"], logger=logger)
+            logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
+            if self.storage_manager.async_mode is False:
+                llm_save(
+                    os.path.join(folder, f"{train_state.step_count}.step"),
+                    saved_obj=dict({"step": train_state.step_count}),
+                )
+
+    def set_save_folder(self, folder, step):
+        self.storage_manager.latest_save_folder = folder
+        self.storage_manager.latest_save_step = step
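For orientation, the auto-resume queries above assume a layout under save_ckpt_folder in which a "<step>.step" marker is written once a save has fully completed and therefore acts as an integrity flag; a sketch with illustrative paths:

import os

save_ckpt_folder = "/mnt/ckpt/run1"  # hypothetical local save folder

# A completed normal checkpoint of step 40 is expected to carry a "40.step" flag next to its shards:
normal_flag = os.path.join(save_ckpt_folder, "40", "40.step")

# Snapshot checkpoints alternate between two fixed slots, each flagged the same way:
snapshot_flag_slot0 = os.path.join(save_ckpt_folder, "snapshot", "0", "60.step")
snapshot_flag_slot1 = os.path.join(save_ckpt_folder, "snapshot", "1", "50.step")

# query_lastest_ckpt() then resumes from whichever of the newest flagged normal ckpt
# and the newest flagged snapshot has the larger step.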
@@ -15,8 +15,6 @@ from asyncio.tasks import ALL_COMPLETED
 from datetime import datetime
 from typing import Any, Awaitable, Callable, Dict, List, Union

-import boto3
-import botocore
 import torch
 import torch.distributed as dist

@@ -24,6 +22,13 @@ from internlm.core.context import global_context as gpc
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger

+try:
+    import boto3
+    import botocore
+except ImportError:
+    pass
+
+
 logger = get_logger(__file__)

 boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
@@ -234,13 +239,13 @@ class Boto3Client(StorageClient):
        """
        paginator = handler.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)

        folder_name_list = []
        for page in pages:
-            for obj in page["Contents"]:
-                fp: str = obj["Key"]
-                folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
-        return folder_name_list
+            if "Contents" in page:
+                for obj in page["Contents"]:
+                    pth: str = obj["Key"]
+                    folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
+        return list(set(folder_name_list))

     @staticmethod
     def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
@@ -391,6 +396,11 @@ class StorageManager(metaclass=SingletonMeta):
         self.tmp_local_folder = tmp_local_folder
         self.async_mode = async_mode
         self.has_warning = False
+        self._async_loop = None
+        self._thread_pool = None
+        self.latest_save_folder = None
+        self.latest_save_step = 0
+        self.async_task_peeding = False

         if enable_save and self.async_mode:
             self._async_loop = asyncio.new_event_loop()
|
||||||
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
|
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
|
||||||
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||||
|
self.async_task_peeding = True
|
||||||
else:
|
else:
|
||||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||||
self.upload_count += 1
|
self.upload_count += 1
|
||||||
|
@ -523,23 +534,22 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _sync_tasks(self) -> Awaitable[None]:
|
async def _sync_tasks(self) -> Awaitable[None]:
|
||||||
if not self._async_stack:
|
if self._async_stack:
|
||||||
return
|
|
||||||
|
|
||||||
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
|
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
|
||||||
|
count = 0
|
||||||
for task in self._async_stack:
|
while self._async_stack:
|
||||||
|
t = self._async_stack[0]
|
||||||
try:
|
try:
|
||||||
task.exception()
|
e = t.exception()
|
||||||
|
if e:
|
||||||
|
self._exception_list.append((e, count))
|
||||||
|
logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}")
|
||||||
|
# raise e
|
||||||
|
count += 1
|
||||||
|
self._async_stack.pop(0)
|
||||||
except InvalidStateError:
|
except InvalidStateError:
|
||||||
continue
|
# Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
|
||||||
except Exception as e:
|
pass
|
||||||
file_id = len(self._exception_list)
|
|
||||||
self._exception_list.append((e, file_id))
|
|
||||||
|
|
||||||
logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}")
|
|
||||||
|
|
||||||
self._async_stack.clear()
|
|
||||||
|
|
||||||
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
|
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -559,11 +569,14 @@ class StorageManager(metaclass=SingletonMeta):
|
||||||
if not self.async_mode:
|
if not self.async_mode:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not self.async_task_peeding:
|
||||||
|
return
|
||||||
|
|
||||||
if self._async_loop:
|
if self._async_loop:
|
||||||
self._async_loop.run_until_complete(self._sync_tasks())
|
self._async_loop.run_until_complete(self._sync_tasks())
|
||||||
|
|
||||||
if self._exception_list:
|
if self._exception_list:
|
||||||
for file_id, error_msg in self._exception_list:
|
for error_msg, file_id in self._exception_list:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
|
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
|
||||||
f"failed on step {self.upload_count}: {error_msg}"
|
f"failed on step {self.upload_count}: {error_msg}"
|
||||||
|
@@ -577,10 +590,16 @@ class StorageManager(metaclass=SingletonMeta):
         self._del_tmp_folder()
         self._exception_list.clear()
         self._to_be_del_files.clear()
+        self.async_task_peeding = False

         if gpc.is_rank_for_log():
-            logger.info("all async uploads succeeded!")
             self.upload_count += 1
+            if self.async_mode:
+                self.save(
+                    os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
+                    saved_obj=dict({"step": self.latest_save_step}),
+                    async_upload=False,
+                )


 storage_manager: StorageManager = None
@@ -11,10 +11,6 @@ from torch.utils.tensorboard import SummaryWriter
 from internlm.core.context import global_context as gpc


-def copy_ignore_folder(source_path, target_path):
-    os.system(f"cp -r {source_path}/* {target_path}/")
-
-
 def tb_save_run_info(writer, config_lines, global_step=0):
     writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step)
     lines = []

@@ -44,7 +40,8 @@ def init_tb_writer(
     if gpc.get_global_rank() == 0:
         if resume_tb_folder is not None:
             logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
-            copy_ignore_folder(resume_tb_folder, tb_folder)
+            os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
+            os.system(f"chmod -R +w {tb_folder}/")
         else:
             logger.info(f"Login tensorboard logs to: {tb_folder}")

train.py (68 changed lines)
@@ -47,14 +47,7 @@ from internlm.utils.common import (
 from internlm.utils.evaluation import evaluate_on_val_dls
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.model_checkpoint import (
-    CheckpointSaveManager,
-    load_context,
-    load_model_checkpoint,
-    load_optimizer_checkpoint,
-    load_sampler,
-    load_scheduler,
-)
+from internlm.utils.model_checkpoint import CheckpointManager
 from internlm.utils.parallel import (
     get_parallel_log_file_name,
     is_no_pp_or_last_stage,
@@ -462,13 +455,9 @@ def main(args):
     skip_batches = gpc.config.data.skip_batches
     total_steps = gpc.config.data.total_steps
     valid_every = gpc.config.data.valid_every
-    load_optimizer = gpc.config.ckpt.load_optimizer
     label_smoothing = gpc.config.loss.label_smoothing
     lr = gpc.config.adam.lr

-    load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
-    load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
-
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
@@ -504,32 +493,19 @@ def main(args):
         enable_tb=gpc.config.enable_tb,
     )

-    model_load_path = None
-    if load_resume_ckpt_folder is not None:
-        logger.info(
-            f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
-            f"{socket.gethostname()}==========="
-        )
-        model_load_path = load_resume_ckpt_folder
-    elif load_model_only_folder is not None:
-        logger.info(
-            f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
-            f"{socket.gethostname()}==========="
-        )
-        model_load_path = load_model_only_folder
-    else:
-        logger.info(
-            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
-            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
-            f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
-        )
-
     # initialize and resume train state
     train_state = TrainState(gpc.config)

     # initialize model
     model = initialize_model()

+    ckpt_manager = CheckpointManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        model_config=gpc.config.model,
+        feishu_address=gpc.config.alert_address,
+    )
+
     # initialize loss function
     criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)

@@ -539,30 +515,12 @@ def main(args):
     train_state.init_batch_sampler(train_dl)

     # Loading model weights must be done before zero is initialized.
-    if model_load_path is not None:
-        load_model_checkpoint(folder=model_load_path, model=model)
+    ckpt_manager.try_load_model(current_time)

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

     # Loading other persistent training states.
-    if load_resume_ckpt_folder is not None:
-        # load lr scheduler states.
-        load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
-        # load training states.
-        load_context(load_resume_ckpt_folder, train_dl, train_state)
-        # load dataloader sampler states.
-        load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler)
-        # load optimzier states.
-        if load_optimizer:
-            load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
-
-    ckpt_save_manager = CheckpointSaveManager(
-        ckpt_config=gpc.config.ckpt,
-        model=model,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model_config=gpc.config.model,
-    )
+    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

     # initialize metric for calculating accuracy and perplexity
     metric = AccPerplex(
@@ -700,14 +658,16 @@ def main(args):

             # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
             # # save batch sampler that tracks the true consumed samples
-            ckpt_save_manager.try_save_checkpoint(train_state)
+            now_break = ckpt_manager.try_save_checkpoint(train_state)
+            if now_break:
+                break

             if memory_profiler is not None:
                 memory_profiler.step()

             prof.step()

-    ckpt_save_manager.wait_async_upload_finish()
+    ckpt_manager.wait_async_upload_finish()


 if __name__ == "__main__":
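Taken together, the train.py changes reduce checkpoint handling to the CheckpointManager calls shown in the hunks above. A condensed sketch of that call sequence (not a complete, runnable main(); all surrounding objects come from the existing training script):

ckpt_manager = CheckpointManager(
    ckpt_config=gpc.config.ckpt,
    model=model,
    model_config=gpc.config.model,
    feishu_address=gpc.config.alert_address,
)

ckpt_manager.try_load_model(current_time)  # must run before zero/optimizer initialization
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

for batch_count in range(train_state.batch_count, total_steps):
    # ... forward / backward / optimizer step ...
    if ckpt_manager.try_save_checkpoint(train_state):  # returns True once a quit signal fired
        break

ckpt_manager.wait_async_upload_finish()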