feat(ckpt): add train config into ckpt (#231)

pull/233/head
Guoteng 2023-08-24 19:57:32 +08:00 committed by GitHub
parent 29dd401071
commit 42851be36b
5 changed files with 24 additions and 7 deletions

View File

@@ -5,7 +5,7 @@ set -x
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
 readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
-expected_num=21
+expected_num=22
 exit_code=0
 source ./ci_scripts/common/basic_func.sh

View File

@@ -5,7 +5,7 @@ set -x
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
 readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt"
-expected_num=21
+expected_num=22
 exit_code=0
 source ./ci_scripts/common/basic_func.sh

View File

@@ -5,7 +5,7 @@ set -x
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
 readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt"
-expected_num=21
+expected_num=22
 exit_code=0
 source ./ci_scripts/common/basic_func.sh
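
The expected count rises from 21 to 22 in all three scripts because each checkpoint folder now holds one additional `.pt` file, the `config_file.pt` written by the checkpoint manager below. The following is a minimal Python sketch of the kind of count check these scripts rely on; the real check lives in ci_scripts/common/basic_func.sh, so the function name and argument wiring here are illustrative assumptions.

# Illustrative sketch only: the actual CI check is implemented in shell in
# ci_scripts/common/basic_func.sh; the function name and wiring are assumptions.
import glob
import os
import sys


def check_ckpt_file_count(ckpt_dir: str, expected_num: int) -> int:
    """Return 0 when the checkpoint folder holds the expected number of .pt files."""
    actual = len(glob.glob(os.path.join(ckpt_dir, "*.pt")))
    if actual != expected_num:
        print(f"expected {expected_num} .pt files in {ckpt_dir}, found {actual}")
        return 1
    return 0


if __name__ == "__main__":
    # e.g. the step-40 snapshot folder used above, now expected to hold 22 files
    sys.exit(check_ckpt_file_count(os.path.expandvars("$GITHUB_WORKSPACE/llm_ckpts/40"), 22))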

View File

@@ -270,7 +270,7 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
 class CheckpointManager:
     """StorageManagerContext"""

-    def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
+    def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
         """
         CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
         upload mode, you must call wait_async_upload_finish at the end of the program to wait
@@ -297,6 +297,7 @@ class CheckpointManager:
         self.model = model
         self.model_config = model_config
+        self.model_config_file = model_config_file

         if self.stop_file_path and gpc.get_global_rank() == 0:
             dir_path = os.path.dirname(self.stop_file_path)
@@ -395,6 +396,7 @@ now step_count is {train_state.step_count}",
                 scheduler=self.lr_scheduler,
                 train_state=train_state,
                 model_config=self.model_config,
+                model_config_file=self.model_config_file,
             )

         return now_break
@@ -558,7 +560,16 @@ set load_ckpt_folder or use default value \
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler

-    def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
+    def save_checkpoint(
+        self,
+        folder,
+        model,
+        optimizer,
+        scheduler,
+        train_state: TrainState,
+        model_config: Dict = None,
+        model_config_file: str = None,
+    ):
         """
         Save checkpoint to the given folder path.
         """
@@ -599,8 +610,13 @@ set load_ckpt_folder or use default value \
         llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())

         if model_config is not None:
+            # Model configuration dictionary.
             llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
+        if model_config_file is not None:
+            # The complete training config file content, stored in binary format.
+            llm_save(os.path.join(folder, "config_file.pt"), saved_obj=model_config_file)

         torch.distributed.barrier()

         if gpc.is_rank_for_log():
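
With this change a checkpoint folder carries both `model_config.pt` (the model configuration dictionary) and `config_file.pt` (the full text of the training config file). Below is a minimal sketch of reading the stored config text back, assuming `llm_save` ultimately serializes the object with `torch.save` to a locally reachable folder; remote (e.g. boto3) destinations would need a download step first.

# Sketch under the assumption that llm_save wrote the object via torch.save to a
# local checkpoint folder; the folder path below is an example, not from this PR.
import os

import torch


def read_saved_train_config(ckpt_folder: str) -> str:
    """Recover the training config text stored by save_checkpoint as config_file.pt."""
    config_text = torch.load(os.path.join(ckpt_folder, "config_file.pt"))
    assert isinstance(config_text, str), "config_file.pt should hold the raw config text"
    return config_text


if __name__ == "__main__":
    text = read_saved_train_config("./llm_ckpts/40")
    # Write it back out so a run can be repeated with exactly the same config.
    with open("restored_config.py", "w") as f:
        f.write(text)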

View File

@@ -102,10 +102,13 @@ def main(args):
     # initialize model
     model = initialize_model()

+    with open(args.config, "r") as f:
+        config_lines = f.readlines()
     ckpt_manager = CheckpointManager(
         ckpt_config=gpc.config.ckpt,
         model=model,
         model_config=gpc.config.model,
+        model_config_file="".join(config_lines),
         feishu_address=gpc.config.alert_address,
     )
@@ -126,8 +129,6 @@ def main(args):
     ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

     # initialize customed llm writer
-    with open(args.config, "r") as f:
-        config_lines = f.readlines()
     writer = Writer(
         job_name=gpc.config.JOB_NAME,
         launch_time=current_time,
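
Taken together, a caller wires the new argument roughly as in the sketch below; the import path is an assumption based on this repository's layout, and the constructor arguments mirror the train.py hunk above. Per the CheckpointManager docstring, callers using asynchronous upload should still call wait_async_upload_finish before the program exits.

# Usage sketch, not part of this diff. The import path is assumed; the constructor
# arguments match the hunks above.
from internlm.utils.model_checkpoint import CheckpointManager  # assumed module path


def build_ckpt_manager(config_path: str, model, gpc_config):
    """Create a CheckpointManager that also stores the raw training config text."""
    with open(config_path, "r") as f:
        config_text = f.read()  # equivalent to "".join(f.readlines())
    return CheckpointManager(
        ckpt_config=gpc_config.ckpt,
        model=model,
        model_config=gpc_config.model,
        model_config_file=config_text,
        feishu_address=gpc_config.alert_address,
    )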