mirror of https://github.com/InternLM/InternLM

feat(ckpt): add train config into ckpt (#231)

parent 29dd401071
commit 42851be36b
@@ -5,7 +5,7 @@ set -x
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
 readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
-expected_num=21
+expected_num=22
 exit_code=0
 
 source ./ci_scripts/common/basic_func.sh

@@ -5,7 +5,7 @@ set -x
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
 readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt"
-expected_num=21
+expected_num=22
 exit_code=0
 
 source ./ci_scripts/common/basic_func.sh

@@ -5,7 +5,7 @@ set -x
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
 readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt"
-expected_num=21
+expected_num=22
 exit_code=0
 
 source ./ci_scripts/common/basic_func.sh

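All three CI scripts bump expected_num from 21 to 22 for the same reason: after this commit a saved checkpoint folder contains one additional file, config_file.pt, next to the existing weight/optimizer shards, context.pt, and model_config.pt. A minimal sketch of the kind of count these scripts assert (the real check lives in ci_scripts/common/basic_func.sh; the helper below is hypothetical):

import glob
import os

def num_pt_files(pattern: str) -> int:
    """Count checkpoint files matching a glob such as ${CKPTS40_PATH}/*.pt."""
    return len(glob.glob(pattern))

# One extra file (config_file.pt) per checkpoint after this commit,
# hence the asserted count moves from 21 to 22.
expected_num = 22
assert num_pt_files(os.path.join(os.environ["GITHUB_WORKSPACE"], "llm_ckpts/40/*.pt")) == expected_num
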
@@ -270,7 +270,7 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
 class CheckpointManager:
     """StorageManagerContext"""
 
-    def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
+    def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
         """
         CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
         upload mode, you must call wait_async_upload_finish at the end of the program to wait

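Both new __init__ arguments default to None, so the signature change is backward compatible: call sites that pass only ckpt_config, model, and model_config continue to work. A hypothetical minimal construction under that assumption:

# Sketch: a caller that predates this commit still works unchanged,
# because model_config and model_config_file both default to None.
ckpt_manager = CheckpointManager(ckpt_config=gpc.config.ckpt, model=model)
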
@@ -297,6 +297,7 @@ class CheckpointManager:
 
         self.model = model
         self.model_config = model_config
+        self.model_config_file = model_config_file
 
         if self.stop_file_path and gpc.get_global_rank() == 0:
             dir_path = os.path.dirname(self.stop_file_path)

@@ -395,6 +396,7 @@ now step_count is {train_state.step_count}",
                 scheduler=self.lr_scheduler,
                 train_state=train_state,
                 model_config=self.model_config,
+                model_config_file=self.model_config_file,
             )
 
         return now_break

@@ -558,7 +560,16 @@ set load_ckpt_folder or use default value \
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
 
-    def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
+    def save_checkpoint(
+        self,
+        folder,
+        model,
+        optimizer,
+        scheduler,
+        train_state: TrainState,
+        model_config: Dict = None,
+        model_config_file: str = None,
+    ):
         """
         Save checkpoint to the given folder path.
         """

@@ -599,8 +610,13 @@ set load_ckpt_folder or use default value \
         llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
 
         if model_config is not None:
+            # Model configuration dictionary.
             llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
 
+        if model_config_file is not None:
+            # The complete training config file content, stored in binary format.
+            llm_save(os.path.join(folder, "config_file.pt"), saved_obj=model_config_file)
+
         torch.distributed.barrier()
 
         if gpc.is_rank_for_log():

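Because model_config_file is the raw text of the launch config, getting it back out of a checkpoint is a single load. A sketch of the round trip, assuming llm_load is the symmetric counterpart of llm_save used elsewhere in this checkpoint module:

import os

from internlm.utils.storage_manager import llm_load  # assumed import path

def read_saved_train_config(folder: str) -> str:
    # Returns the exact config file text that was passed to save_checkpoint.
    return llm_load(os.path.join(folder, "config_file.pt"))
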
train.py | 5 +++--

@@ -102,10 +102,13 @@ def main(args):
     # initialize model
     model = initialize_model()
 
+    with open(args.config, "r") as f:
+        config_lines = f.readlines()
     ckpt_manager = CheckpointManager(
         ckpt_config=gpc.config.ckpt,
         model=model,
         model_config=gpc.config.model,
+        model_config_file="".join(config_lines),
         feishu_address=gpc.config.alert_address,
     )
 

@@ -126,8 +129,6 @@ def main(args):
     ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
 
     # initialize customed llm writer
-    with open(args.config, "r") as f:
-        config_lines = f.readlines()
     writer = Writer(
         job_name=gpc.config.JOB_NAME,
         launch_time=current_time,

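Net effect in train.py: the config file is now read once, earlier in main(), and the duplicate read before the Writer is dropped, presumably so the already-loaded config_lines can be reused further down. Note that the readlines() plus "".join(...) pair is equivalent to a single read(); a sketch, not part of the commit:

# Equivalent to the readlines() + "".join(...) used when building
# the CheckpointManager above:
with open(args.config, "r") as f:
    config_text = f.read()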