restore train.py

pull/407/head
yingtongxiong 2023-10-09 20:08:49 +08:00
parent ef9e7cc622
commit 5d39c332fe
1 changed file with 17 additions and 5 deletions

View File

@@ -110,6 +110,7 @@ def main(args):
# initialize and resume train state # initialize and resume train state
train_state = TrainState(gpc.config, train_dl.batch_sampler) train_state = TrainState(gpc.config, train_dl.batch_sampler)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager = CheckpointManager( ckpt_manager = CheckpointManager(
@@ -169,7 +170,6 @@ def main(args):
beta2_scheduler=beta2_scheduler, beta2_scheduler=beta2_scheduler,
scheduler_hooks=scheduler_hooks, scheduler_hooks=scheduler_hooks,
) )
# initialize simple memory profiler # initialize simple memory profiler
if args.profiling: if args.profiling:
@@ -219,9 +219,21 @@ def main(args):
# do forward and backward # do forward and backward
timer("fwd-bwd").start() timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule( moe_loss = None
batch, forward_only=False, return_loss=True, return_output_label=False if hasattr(gpc.config.model, "num_experts"):
) _, _, loss, moe_loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
else:
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
timer("fwd-bwd").stop() timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm) # update parameters, and returns (success_update, grad_norm)
@@ -254,7 +266,7 @@ def main(args):
trainer=trainer, trainer=trainer,
start_time=start_time, start_time=start_time,
loss=loss, loss=loss,
moe_loss=None, moe_loss=moe_loss,
grad_norm=grad_norm_groups, grad_norm=grad_norm_groups,
metric=metric, metric=metric,
update_panel=uniscale_logger is not None, update_panel=uniscale_logger is not None,