restore train.py

pull/407/head
yingtongxiong 2023-10-09 20:08:49 +08:00
parent ef9e7cc622
commit 5d39c332fe
1 changed file with 17 additions and 5 deletions

@@ -110,6 +110,7 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
     ckpt_manager = CheckpointManager(
@@ -169,7 +170,6 @@ def main(args):
         beta2_scheduler=beta2_scheduler,
         scheduler_hooks=scheduler_hooks,
     )
     # initialize simple memory profiler
     if args.profiling:
@@ -219,9 +219,21 @@ def main(args):
         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(
-            batch, forward_only=False, return_loss=True, return_output_label=False
-        )
+        moe_loss = None
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
+                batch,
+                forward_only=False,
+                return_loss=True,
+                return_output_label=False,
+            )
+        else:
+            _, _, loss = trainer.execute_schedule(
+                batch,
+                forward_only=False,
+                return_loss=True,
+                return_output_label=False,
+            )
         timer("fwd-bwd").stop()
         # update parameters, and returns (success_update, grad_norm)
@@ -254,7 +266,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            moe_loss=None,
+            moe_loss=moe_loss,
             grad_norm=grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,
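The restored logic makes the training step MoE-aware: when the model config defines num_experts, trainer.execute_schedule returns an extra auxiliary loss, which is then forwarded to the metrics logger instead of the previously hard-coded None. Below is a minimal standalone sketch of that dispatch pattern; the SimpleNamespace configs and the execute_schedule_stub are illustrative stand-ins, not InternLM APIs.

from types import SimpleNamespace

# Illustrative stand-ins for gpc.config.model; a real MoE config carries num_experts.
dense_cfg = SimpleNamespace(hidden_size=4096)
moe_cfg = SimpleNamespace(hidden_size=4096, num_experts=8)


def execute_schedule_stub(batch, moe=False):
    """Stand-in for trainer.execute_schedule: MoE schedules return an extra aux loss."""
    loss = 1.23
    if moe:
        return None, None, loss, 0.045  # (output, label, loss, moe_loss)
    return None, None, loss


def train_step(model_cfg, batch):
    # Mirror the restored branch: only unpack moe_loss when the config defines experts.
    moe_loss = None
    if hasattr(model_cfg, "num_experts"):
        _, _, loss, moe_loss = execute_schedule_stub(batch, moe=True)
    else:
        _, _, loss = execute_schedule_stub(batch, moe=False)
    return loss, moe_loss


if __name__ == "__main__":
    for cfg in (dense_cfg, moe_cfg):
        loss, moe_loss = train_step(cfg, batch=None)
        # moe_loss stays None for dense models, so downstream metric logging
        # can receive it unconditionally, as in the hunk at line 266 above.
        print(f"loss={loss}, moe_loss={moe_loss}")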