mirror of https://github.com/InternLM/InternLM

restore train.py

parent ef9e7cc622
commit 5d39c332fe

1 changed file: train.py (22 changed lines)
@@ -110,6 +110,7 @@ def main(args):
+
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)
 
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 
     ckpt_manager = CheckpointManager(
@@ -169,7 +170,6 @@ def main(args):
         beta2_scheduler=beta2_scheduler,
         scheduler_hooks=scheduler_hooks,
     )
 
-
     # initialize simple memory profiler
     if args.profiling:
@@ -219,9 +219,21 @@ def main(args):
         # do forward and backward
         timer("fwd-bwd").start()
 
-        _, _, loss = trainer.execute_schedule(
-            batch, forward_only=False, return_loss=True, return_output_label=False
-        )
+        moe_loss = None
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
+                batch,
+                forward_only=False,
+                return_loss=True,
+                return_output_label=False,
+            )
+        else:
+            _, _, loss = trainer.execute_schedule(
+                batch,
+                forward_only=False,
+                return_loss=True,
+                return_output_label=False,
+            )
         timer("fwd-bwd").stop()
 
         # update parameters, and returns (success_update, grad_norm)
@@ -254,7 +266,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            moe_loss=None,
+            moe_loss=moe_loss,
             grad_norm=grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,
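
The core of the commit is the MoE-aware dispatch in the training loop: when the model config declares `num_experts`, the schedule returns one extra auxiliary (load-balancing) loss that must be unpacked and forwarded to the metrics logger. Below is a minimal, self-contained sketch of that pattern, not InternLM's actual implementation: `ToyConfig`, the stand-in `execute_schedule`, and `log_metrics` are hypothetical substitutes for `gpc.config.model`, `trainer.execute_schedule`, and `record_current_batch_training_metrics`.

# Hypothetical sketch of the dispatch pattern restored by this commit.
# Assumption: an MoE schedule returns (output, label, loss, moe_loss),
# a dense schedule returns (output, label, loss).
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToyConfig:
    # Present only for MoE models, mirroring gpc.config.model.num_experts.
    num_experts: Optional[int] = None


def execute_schedule(batch, is_moe: bool):
    """Stand-in for trainer.execute_schedule(): the MoE path yields an
    extra auxiliary loss alongside the main loss."""
    loss = sum(batch) * 0.1  # dummy "loss"
    if is_moe:
        moe_loss = 0.01  # dummy auxiliary load-balancing loss
        return None, None, loss, moe_loss
    return None, None, loss


def train_step(config: ToyConfig, batch) -> Tuple[float, Optional[float]]:
    # Mirror of the restored branch: probe the config for `num_experts`
    # (the real code uses hasattr(gpc.config.model, "num_experts")) and
    # unpack one extra return value on the MoE path.
    moe_loss = None
    if config.num_experts is not None:
        _, _, loss, moe_loss = execute_schedule(batch, is_moe=True)
    else:
        _, _, loss = execute_schedule(batch, is_moe=False)
    return loss, moe_loss


def log_metrics(loss: float, moe_loss: Optional[float]) -> None:
    # The logger tolerates moe_loss=None on dense models, so one call
    # site serves both paths (as in the last hunk of the diff).
    msg = f"loss={loss:.4f}"
    if moe_loss is not None:
        msg += f" moe_loss={moe_loss:.4f}"
    print(msg)


if __name__ == "__main__":
    batch = [1.0, 2.0, 3.0]
    for cfg in (ToyConfig(), ToyConfig(num_experts=8)):
        log_metrics(*train_step(cfg, batch))

Defaulting `moe_loss = None` keeps the dense path and the logging call site unchanged, which is why the final hunk of the diff only swaps `moe_loss=None` for `moe_loss=moe_loss`.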