fix the moe_loss for ci and val

pull/182/head
zhanglei 2023-09-22 15:45:36 +08:00
parent 3df0a51555
commit ccdaf8ec45
2 changed files with 4 additions and 4 deletions

View File

@@ -112,7 +112,7 @@ def evaluate_on_val_dls(
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
-_, _, loss, _ = trainer.execute_schedule(
+_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
@@ -126,11 +126,11 @@ def evaluate_on_val_dls(
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
-_, _, loss, _ = trainer.execute_schedule(
+_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
-val_loss += loss.item()
+val_loss += loss.item() - moe_loss.item()
assert val_idx != -1
dist.barrier()

View File

@@ -186,7 +186,7 @@ def train(
# do forward and backward
timer("fwd-bwd").start()
-_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
+_, _, loss, _ = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
if gpc.is_rank_for_log():
assert loss is not None and not math.isnan(loss.item())
global cur_loss_list