restore train.py

pull/407/head
yingtongxiong 2023-10-09 20:08:49 +08:00
parent ef9e7cc622
commit 5d39c332fe
1 changed file with 17 additions and 5 deletions

View File

@@ -110,6 +110,7 @@ def main(args):
# initialize and resume train state # initialize and resume train state
train_state = TrainState(gpc.config, train_dl.batch_sampler) train_state = TrainState(gpc.config, train_dl.batch_sampler)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager = CheckpointManager( ckpt_manager = CheckpointManager(
@@ -170,7 +171,6 @@ def main(args):
scheduler_hooks=scheduler_hooks, scheduler_hooks=scheduler_hooks,
) )
# initialize simple memory profiler # initialize simple memory profiler
if args.profiling: if args.profiling:
memory_profiler = SimpleMemoryProfiler( memory_profiler = SimpleMemoryProfiler(
@@ -219,8 +219,20 @@ def main(args):
# do forward and backward # do forward and backward
timer("fwd-bwd").start() timer("fwd-bwd").start()
moe_loss = None
if hasattr(gpc.config.model, "num_experts"):
_, _, loss, moe_loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
else:
_, _, loss = trainer.execute_schedule( _, _, loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False batch,
forward_only=False,
return_loss=True,
return_output_label=False,
) )
timer("fwd-bwd").stop() timer("fwd-bwd").stop()
@@ -254,7 +266,7 @@ def main(args):
trainer=trainer, trainer=trainer,
start_time=start_time, start_time=start_time,
loss=loss, loss=loss,
moe_loss=None, moe_loss=moe_loss,
grad_norm=grad_norm_groups, grad_norm=grad_norm_groups,
metric=metric, metric=metric,
update_panel=uniscale_logger is not None, update_panel=uniscale_logger is not None,