mirror of https://github.com/InternLM/InternLM
merge upstream/develop into feature_add_moe
parent
8a595837fc
commit
b46d1c17af
46
train.py
46
train.py
|
@ -35,6 +35,7 @@ from internlm.utils.common import (
|
|||
parse_args,
|
||||
)
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.gputest import empty_cache_and_diag
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
|
@ -72,7 +73,6 @@ def main(args):
|
|||
total_steps = gpc.config.data.total_steps
|
||||
valid_every = gpc.config.data.valid_every
|
||||
label_smoothing = gpc.config.loss.label_smoothing
|
||||
lr = gpc.config.adam.lr
|
||||
|
||||
get_tflops_func = partial(
|
||||
get_megatron_flops,
|
||||
|
@ -95,21 +95,11 @@ def main(args):
|
|||
# initialize customed llm logger
|
||||
uniscale_logger = initialize_llm_logger(start_time=current_time)
|
||||
|
||||
# initialize and resume train state
|
||||
train_state = TrainState(gpc.config)
|
||||
|
||||
# initialize model
|
||||
model = initialize_model()
|
||||
|
||||
with open(args.config, "r") as f:
|
||||
config_lines = f.readlines()
|
||||
ckpt_manager = CheckpointManager(
|
||||
ckpt_config=gpc.config.ckpt,
|
||||
model=model,
|
||||
model_config=gpc.config.model,
|
||||
model_config_file="".join(config_lines),
|
||||
feishu_address=gpc.config.alert_address,
|
||||
)
|
||||
|
||||
# initialize loss function
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
||||
|
@ -117,15 +107,25 @@ def main(args):
|
|||
# initialize the train and validation data loader
|
||||
train_dl, dataset_types = get_train_data_loader(num_worker=4)
|
||||
val_dls = get_validation_data_loader()
|
||||
train_state.init_batch_sampler(train_dl)
|
||||
|
||||
# Loading model weights must be done before zero is initialized.
|
||||
ckpt_manager.try_load_model(current_time)
|
||||
# initialize and resume train state
|
||||
train_state = TrainState(gpc.config, train_dl.batch_sampler)
|
||||
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
||||
|
||||
ckpt_manager = CheckpointManager(
|
||||
ckpt_config=gpc.config.ckpt,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
train_dl=train_dl,
|
||||
model_config=gpc.config.model,
|
||||
model_config_file="".join(config_lines),
|
||||
feishu_address=gpc.config.monitor.alert.feishu_alert_address,
|
||||
)
|
||||
|
||||
# Loading other persistent training states.
|
||||
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
|
||||
ckpt_manager.try_resume_training(train_state, current_time)
|
||||
|
||||
# initialize customed llm writer
|
||||
writer = Writer(
|
||||
|
@ -194,9 +194,7 @@ def main(args):
|
|||
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
|
||||
|
@ -238,10 +236,10 @@ def main(args):
|
|||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case
|
||||
if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address,
|
||||
address=gpc.config.monitor.alert.feishu_alert_address,
|
||||
message=f"Warning: skip parameter update at step {batch_count}.",
|
||||
)
|
||||
|
||||
|
@ -302,11 +300,15 @@ if __name__ == "__main__":
|
|||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# initialize monitor manager context
|
||||
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
|
||||
with initialize_monitor_manager(
|
||||
job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
|
||||
):
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
|
||||
)
|
||||
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
|
||||
mm.monitor_exception(
|
||||
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue