|
|
|
@ -269,14 +269,11 @@ def main():
|
|
|
|
|
# Backward and optimize |
|
|
|
|
if is_pp_last_stage: |
|
|
|
|
loss = outputs["loss"] |
|
|
|
|
# aux_loss = outputs["outputs"]["aux_loss"] |
|
|
|
|
|
|
|
|
|
global_loss = get_global_loss(loss, booster) |
|
|
|
|
# global_aux_loss = get_global_loss(aux_loss, booster) |
|
|
|
|
if coordinator._local_rank == "0": |
|
|
|
|
pbar.set_postfix({"Loss": global_loss.item()}) |
|
|
|
|
writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=step) |
|
|
|
|
# writer.add_scalar(tag="Aux Loss", scalar_value=global_aux_loss.item(), global_step=step) |
|
|
|
|
writer.add_scalar( |
|
|
|
|
tag="Learning Rate", |
|
|
|
|
scalar_value=lr_scheduler.get_last_lr()[0], |
|
|
|
@ -320,7 +317,7 @@ def main():
|
|
|
|
|
optimizer, |
|
|
|
|
lr_scheduler, |
|
|
|
|
epoch, |
|
|
|
|
step, |
|
|
|
|
step + 1, |
|
|
|
|
args.batch_size, |
|
|
|
|
coordinator, |
|
|
|
|
) |
|
|
|
|