merge upstream/develop into feature_add_moe

pull/182/head
Wenwen Qu 2023-09-11 16:27:33 +08:00
parent 8a595837fc
commit b46d1c17af
1 changed files with 24 additions and 22 deletions

View File

@ -35,6 +35,7 @@ from internlm.utils.common import (
parse_args, parse_args,
) )
from internlm.utils.evaluation import evaluate_on_val_dls from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager from internlm.utils.model_checkpoint import CheckpointManager
@ -72,7 +73,6 @@ def main(args):
total_steps = gpc.config.data.total_steps total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every valid_every = gpc.config.data.valid_every
label_smoothing = gpc.config.loss.label_smoothing label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
get_tflops_func = partial( get_tflops_func = partial(
get_megatron_flops, get_megatron_flops,
@ -95,21 +95,11 @@ def main(args):
# initialize customed llm logger # initialize customed llm logger
uniscale_logger = initialize_llm_logger(start_time=current_time) uniscale_logger = initialize_llm_logger(start_time=current_time)
# initialize and resume train state
train_state = TrainState(gpc.config)
# initialize model # initialize model
model = initialize_model() model = initialize_model()
with open(args.config, "r") as f: with open(args.config, "r") as f:
config_lines = f.readlines() config_lines = f.readlines()
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.alert_address,
)
# initialize loss function # initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@ -117,15 +107,25 @@ def main(args):
# initialize the train and validation data loader # initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=4) train_dl, dataset_types = get_train_data_loader(num_worker=4)
val_dls = get_validation_data_loader() val_dls = get_validation_data_loader()
train_state.init_batch_sampler(train_dl)
# Loading model weights must be done before zero is initialized. # initialize and resume train state
ckpt_manager.try_load_model(current_time) train_state = TrainState(gpc.config, train_dl.batch_sampler)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
train_dl=train_dl,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.monitor.alert.feishu_alert_address,
)
# Loading other persistent training states. # Loading other persistent training states.
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl) ckpt_manager.try_resume_training(train_state, current_time)
# initialize customed llm writer # initialize customed llm writer
writer = Writer( writer = Writer(
@ -194,9 +194,7 @@ def main(args):
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
# start iterating the train data and begin training # start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps): for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0: empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
torch.cuda.empty_cache()
start_time = time.time() start_time = time.time()
timer("one-batch").start() timer("one-batch").start()
@ -238,10 +236,10 @@ def main(args):
train_state.step_count += 1 train_state.step_count += 1
else: else:
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.") logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message( send_alert_message(
address=gpc.config.alert_address, address=gpc.config.monitor.alert.feishu_alert_address,
message=f"Warning: skip parameter update at step {batch_count}.", message=f"Warning: skip parameter update at step {batch_count}.",
) )
@ -302,11 +300,15 @@ if __name__ == "__main__":
assert hasattr(gpc, "config") and gpc.config is not None assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context # initialize monitor manager context
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address): with initialize_monitor_manager(
job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
):
try: try:
main(args) main(args)
except Exception: except Exception:
logger.error( logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
) )
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc()) mm.monitor_exception(
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
)