diff --git a/train.py b/train.py index 1adcc22..139bac1 100644 --- a/train.py +++ b/train.py @@ -110,6 +110,7 @@ def main(args): # initialize and resume train state train_state = TrainState(gpc.config, train_dl.batch_sampler) + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) ckpt_manager = CheckpointManager( @@ -169,7 +170,6 @@ def main(args): beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) - # initialize simple memory profiler if args.profiling: @@ -219,9 +219,21 @@ def main(args): # do forward and backward timer("fwd-bwd").start() - _, _, loss = trainer.execute_schedule( - batch, forward_only=False, return_loss=True, return_output_label=False - ) + moe_loss = None + if hasattr(gpc.config.model, "num_experts"): + _, _, loss, moe_loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + else: + _, _, loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) timer("fwd-bwd").stop() # update parameters, and returns (success_update, grad_norm) @@ -254,7 +266,7 @@ def main(args): trainer=trainer, start_time=start_time, loss=loss, - moe_loss=None, + moe_loss=moe_loss, grad_norm=grad_norm_groups, metric=metric, update_panel=uniscale_logger is not None,