mirror of https://github.com/InternLM/InternLM
restore train.py
parent
ef9e7cc622
commit
5d39c332fe
22
train.py
22
train.py
|
@ -110,6 +110,7 @@ def main(args):
|
||||||
|
|
||||||
# initialize and resume train state
|
# initialize and resume train state
|
||||||
train_state = TrainState(gpc.config, train_dl.batch_sampler)
|
train_state = TrainState(gpc.config, train_dl.batch_sampler)
|
||||||
|
|
||||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
||||||
|
|
||||||
ckpt_manager = CheckpointManager(
|
ckpt_manager = CheckpointManager(
|
||||||
|
@ -169,7 +170,6 @@ def main(args):
|
||||||
beta2_scheduler=beta2_scheduler,
|
beta2_scheduler=beta2_scheduler,
|
||||||
scheduler_hooks=scheduler_hooks,
|
scheduler_hooks=scheduler_hooks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# initialize simple memory profiler
|
# initialize simple memory profiler
|
||||||
if args.profiling:
|
if args.profiling:
|
||||||
|
@ -219,9 +219,21 @@ def main(args):
|
||||||
# do forward and backward
|
# do forward and backward
|
||||||
timer("fwd-bwd").start()
|
timer("fwd-bwd").start()
|
||||||
|
|
||||||
_, _, loss = trainer.execute_schedule(
|
moe_loss = None
|
||||||
batch, forward_only=False, return_loss=True, return_output_label=False
|
if hasattr(gpc.config.model, "num_experts"):
|
||||||
)
|
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||||
|
batch,
|
||||||
|
forward_only=False,
|
||||||
|
return_loss=True,
|
||||||
|
return_output_label=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_, _, loss = trainer.execute_schedule(
|
||||||
|
batch,
|
||||||
|
forward_only=False,
|
||||||
|
return_loss=True,
|
||||||
|
return_output_label=False,
|
||||||
|
)
|
||||||
timer("fwd-bwd").stop()
|
timer("fwd-bwd").stop()
|
||||||
|
|
||||||
# update parameters, and returns (success_update, grad_norm)
|
# update parameters, and returns (success_update, grad_norm)
|
||||||
|
@ -254,7 +266,7 @@ def main(args):
|
||||||
trainer=trainer,
|
trainer=trainer,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
moe_loss=None,
|
moe_loss=moe_loss,
|
||||||
grad_norm=grad_norm_groups,
|
grad_norm=grad_norm_groups,
|
||||||
metric=metric,
|
metric=metric,
|
||||||
update_panel=uniscale_logger is not None,
|
update_panel=uniscale_logger is not None,
|
||||||
|
|
Loading…
Reference in New Issue