feat(ckpt): add train config into ckpt (#231)

pull/233/head
Guoteng 2023-08-24 19:57:32 +08:00 committed by GitHub
parent 29dd401071
commit 42851be36b
5 changed files with 24 additions and 7 deletions

View File

@@ -5,7 +5,7 @@ set -x
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
-expected_num=21
+expected_num=22
exit_code=0
source ./ci_scripts/common/basic_func.sh

View File

@@ -5,7 +5,7 @@ set -x
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt"
-expected_num=21
+expected_num=22
exit_code=0
source ./ci_scripts/common/basic_func.sh

View File

@@ -5,7 +5,7 @@ set -x
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt"
-expected_num=21
+expected_num=22
exit_code=0
source ./ci_scripts/common/basic_func.sh
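The expected .pt count in these three CI checks rises from 21 to 22 because every checkpoint snapshot now also contains a config_file.pt (see the model_checkpoint.py change below). A rough Python equivalent of the count check, with the glob pattern and function name assumed for illustration rather than taken from basic_func.sh:

import glob

def check_ckpt_file_count(ckpts_output: str, expected_num: int = 22) -> int:
    # ckpts_output mirrors the CKPTS*_OUTPUT globs above, e.g. ".../llm_ckpts/40/*.pt".
    num_files = len(glob.glob(ckpts_output))
    # Mimics the CI behaviour of reporting a non-zero exit code on a mismatch.
    if num_files != expected_num:
        print(f"expected {expected_num} checkpoint files, found {num_files}")
        return 1
    return 0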

View File

@@ -270,7 +270,7 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
class CheckpointManager:
    """StorageManagerContext"""
-    def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
+    def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
        """
        CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
        upload mode, you must call wait_async_upload_finish at the end of the program to wait
@@ -297,6 +297,7 @@
        self.model = model
        self.model_config = model_config
+        self.model_config_file = model_config_file
        if self.stop_file_path and gpc.get_global_rank() == 0:
            dir_path = os.path.dirname(self.stop_file_path)
@@ -395,6 +396,7 @@ now step_count is {train_state.step_count}",
                scheduler=self.lr_scheduler,
                train_state=train_state,
                model_config=self.model_config,
+                model_config_file=self.model_config_file,
            )
        return now_break
@@ -558,7 +560,16 @@ set load_ckpt_folder or use default value \
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
-    def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
+    def save_checkpoint(
+        self,
+        folder,
+        model,
+        optimizer,
+        scheduler,
+        train_state: TrainState,
+        model_config: Dict = None,
+        model_config_file: str = None,
+    ):
        """
        Save checkpoint to the given folder path.
        """
@@ -599,8 +610,13 @@ set load_ckpt_folder or use default value \
            llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
            if model_config is not None:
+                # Model configuration dictionary.
                llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
+            if model_config_file is not None:
+                # The complete training config file content, stored in binary format.
+                llm_save(os.path.join(folder, "config_file.pt"), saved_obj=model_config_file)
        torch.distributed.barrier()
        if gpc.is_rank_for_log():
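With this change a checkpoint folder carries both model_config.pt (the model configuration dictionary) and config_file.pt (the raw text of the training config), which is also why the CI scripts above expect one more .pt file. A minimal sketch of reading the stored config back; it assumes a local checkpoint folder and plain torch.load instead of the repo's storage helpers, so the calls are illustrative only:

import os

import torch

def restore_training_config(ckpt_folder: str, out_path: str = "restored_config.py") -> str:
    # config_file.pt holds the full config file content as a single string,
    # saved via save_checkpoint(..., model_config_file=...).
    config_text = torch.load(os.path.join(ckpt_folder, "config_file.pt"))
    with open(out_path, "w") as f:
        f.write(config_text)
    return out_path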

View File

@@ -102,10 +102,13 @@ def main(args):
    # initialize model
    model = initialize_model()
+    with open(args.config, "r") as f:
+        config_lines = f.readlines()
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        model_config=gpc.config.model,
+        model_config_file="".join(config_lines),
        feishu_address=gpc.config.alert_address,
    )
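Reading args.config once, right after model initialization, lets the same text serve both the checkpoint manager here and the Writer below, which is why the duplicate read is dropped in the next hunk. A standalone, hypothetical illustration of what ends up in model_config_file, with the config path assumed for the example:

config_path = "configs/7B_sft.py"  # assumed path, stands in for args.config

# Joining readlines() yields the file's complete text as one string, which is
# what CheckpointManager receives and later stores as config_file.pt.
with open(config_path, "r") as f:
    config_lines = f.readlines()

config_text = "".join(config_lines)
print(f"training config is {len(config_text)} characters long")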
@@ -126,8 +129,6 @@ def main(args):
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
    # initialize customed llm writer
-    with open(args.config, "r") as f:
-        config_lines = f.readlines()
    writer = Writer(
        job_name=gpc.config.JOB_NAME,
        launch_time=current_time,