update trainer_result in ci

pull/412/head
JiaoPL 2023-10-18 19:37:15 +08:00
parent 6ce78a4e09
commit 5d0151d7b0
1 changed file with 1 addition and 3 deletions

View File

@@ -1,6 +1,5 @@
import math
import os
import subprocess
import pytest
import torch
@@ -198,7 +197,6 @@ def train(
)
if gpc.is_rank_for_log():
assert loss is not None and not math.isnan(loss.item())
global cur_loss_list
cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item()))
timer("fwd-bwd").stop()
@@ -206,7 +204,7 @@ def train(
trainer_result = trainer.step()
assert trainer_result is not None
success_update, _ = trainer_result
success_update, _, _ = trainer_result
assert success_update, "Error: grad norm inf or nan occurs!"
if success_update: # update parameters successfully
train_state.step_count += 1