mirror of https://github.com/InternLM/InternLM
merge upstream/develop into feature_add_moe
parent 8a595837fc
commit b46d1c17af
train.py (46 changed lines)
@@ -35,6 +35,7 @@ from internlm.utils.common import (
     parse_args,
 )
 from internlm.utils.evaluation import evaluate_on_val_dls
+from internlm.utils.gputest import empty_cache_and_diag
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import CheckpointManager
@@ -72,7 +73,6 @@ def main(args):
     total_steps = gpc.config.data.total_steps
     valid_every = gpc.config.data.valid_every
     label_smoothing = gpc.config.loss.label_smoothing
-    lr = gpc.config.adam.lr

     get_tflops_func = partial(
         get_megatron_flops,
@@ -95,21 +95,11 @@ def main(args):
     # initialize customed llm logger
     uniscale_logger = initialize_llm_logger(start_time=current_time)

-    # initialize and resume train state
-    train_state = TrainState(gpc.config)
-
     # initialize model
     model = initialize_model()

     with open(args.config, "r") as f:
         config_lines = f.readlines()
-    ckpt_manager = CheckpointManager(
-        ckpt_config=gpc.config.ckpt,
-        model=model,
-        model_config=gpc.config.model,
-        model_config_file="".join(config_lines),
-        feishu_address=gpc.config.alert_address,
-    )

     # initialize loss function
     criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
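The two deletions above are relocations, not removals: both TrainState and CheckpointManager are rebuilt in the next hunk, after the objects they now depend on exist. The TrainState change also folds the old two-step setup into a single constructor call. A side-by-side sketch of the two flows, reconstructed from this diff:

    # before: built from config alone, batch sampler attached later,
    # leaving the object half-initialized in between
    train_state = TrainState(gpc.config)
    # ... data loaders created here ...
    train_state.init_batch_sampler(train_dl)

    # after: the sampler is passed at construction time
    train_state = TrainState(gpc.config, train_dl.batch_sampler)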
@@ -117,15 +107,25 @@ def main(args):
     # initialize the train and validation data loader
     train_dl, dataset_types = get_train_data_loader(num_worker=4)
     val_dls = get_validation_data_loader()
-    train_state.init_batch_sampler(train_dl)

-    # Loading model weights must be done before zero is initialized.
-    ckpt_manager.try_load_model(current_time)
+    # initialize and resume train state
+    train_state = TrainState(gpc.config, train_dl.batch_sampler)

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

+    ckpt_manager = CheckpointManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dl=train_dl,
+        model_config=gpc.config.model,
+        model_config_file="".join(config_lines),
+        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
+    )
+
     # Loading other persistent training states.
-    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
+    ckpt_manager.try_resume_training(train_state, current_time)

     # initialize customed llm writer
     writer = Writer(
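Because CheckpointManager now captures optimizer, lr_scheduler, and train_dl at construction, the resume API shrinks: the separate try_load_model call disappears and try_resume_training no longer needs the scheduler, optimizer, learning rate, or dataloader threaded through it. The before/after call sequence, read straight off this diff:

    # before: weights loaded separately, and the resume call carried
    # every dependency explicitly
    ckpt_manager.try_load_model(current_time)
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

    # after: dependencies are captured once by the manager; resuming
    # takes only the mutable training state plus a timestamp
    ckpt_manager.try_resume_training(train_state, current_time)

This also explains the deletion of `lr = gpc.config.adam.lr` in the earlier hunk: lr was read only to feed the old resume signature.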
@@ -194,9 +194,7 @@ def main(args):
     with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
         # start iterating the train data and begin training
         for batch_count in range(train_state.batch_count, total_steps):
-            if batch_count % 50 == 0:
-                torch.cuda.empty_cache()
-
+            empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
             start_time = time.time()
             timer("one-batch").start()

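The hard-coded every-50-steps cache flush is replaced by the newly imported empty_cache_and_diag, with the cadence promoted to a config field (gpc.config.data.empty_cache_and_diag_interval). A minimal sketch of what such a helper could look like; the real implementation lives in internlm/utils/gputest.py and, per its name, likely runs GPU diagnostics as well, which this diff does not show:

    import torch

    def empty_cache_and_diag(batch_count: int, interval: int = 50) -> None:
        # Hypothetical sketch, not the repo's implementation: free cached
        # CUDA blocks every `interval` batches.
        if interval <= 0:
            interval = 50  # fall back to the old hard-coded cadence
        if batch_count % interval == 0:
            torch.cuda.empty_cache()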
@@ -238,10 +236,10 @@ def main(args):
                 train_state.step_count += 1
             else:
                 train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-                if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
+                if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                     logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                     send_alert_message(
-                        address=gpc.config.alert_address,
+                        address=gpc.config.monitor.alert.feishu_alert_address,
                         message=f"Warning: skip parameter update at step {batch_count}.",
                     )
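The membership test moves from the container itself to .values(), which implies grad_norm_groups changed from a sequence of norms to a mapping (e.g., parameter-group name to norm); on a dict, `-1 in grad_norm_groups` would scan the keys and silently miss the failure marker. A small sketch of the assumed shape (group names are illustrative, not taken from the repo):

    # -1 encodes a specific failure case (per the inline comment above):
    # the group's update was skipped, e.g., due to an inf/nan gradient norm.
    grad_norm_groups = {"default": 1.73, "fp32": -1}

    if -1 in grad_norm_groups.values():  # checks norms, not group names
        print("Warning: skip parameter update at this step.")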
@@ -302,11 +300,15 @@ if __name__ == "__main__":
     assert hasattr(gpc, "config") and gpc.config is not None

     # initialize monitor manager context
-    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+    with initialize_monitor_manager(
+        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
+    ):
         try:
             main(args)
         except Exception:
             logger.error(
                 f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
             )
-            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
+            mm.monitor_exception(
+                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
+            )
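Every alert call site in this commit now reads gpc.config.monitor.alert.feishu_alert_address in place of the flat gpc.config.alert_address, so job configs need a nested monitor block. A plausible layout implied by the new attribute path (everything beyond the feishu_alert_address field name is an assumption, not taken from the repo):

    # in the job's Python config file
    monitor = dict(
        alert=dict(
            # Feishu webhook URL; assumed to accept None to disable alerts
            feishu_alert_address=None,
        ),
    )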