Merge branch 'feature_add_moe' of https://github.com/blankde/InternLM into feature_add_moe

pull/182/head
Wenwen Qu 2023-09-22 15:57:36 +08:00
commit 1fdc7107b4
4 changed files with 6 additions and 7 deletions

.gitignore

@@ -132,7 +132,6 @@ runs_bak/
 LLM_ALERT
 small_demo/
 7b_llama_nopp/
-test/
 # Pytorch
 *.pth


@@ -112,7 +112,7 @@ def evaluate_on_val_dls(
                         tensor_shape=tensor_shape,
                         metric_hook_list=[val_sche_metric_hook],
                     ):
-                        _, _, loss, _ = trainer.execute_schedule(
+                        _, _, loss, moe_loss = trainer.execute_schedule(
                             batch, forward_only=True, return_loss=True, return_output_label=False
                         )
                 else:
@@ -126,11 +126,11 @@
                         grad_accum_batch_size=grad_accum_batch_size,
                         metric_hook_list=[val_sche_metric_hook],
                     ):
-                        _, _, loss, _ = trainer.execute_schedule(
+                        _, _, loss, moe_loss = trainer.execute_schedule(
                             batch, forward_only=True, return_loss=True, return_output_label=False
                         )
             if verbose:
-                val_loss += loss.item()
+                val_loss += loss.item() - moe_loss.item()

     assert val_idx != -1
     dist.barrier()

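The evaluation change above relies on trainer.execute_schedule now returning the MoE auxiliary loss as a fourth value; subtracting it keeps val_loss a pure language-modeling loss rather than LM loss plus the load-balancing penalty. A minimal sketch of that accounting (the tensor values below are hypothetical stand-ins for the schedule's outputs):

import torch

# Hypothetical stand-ins: the schedule's reported `loss` is the LM loss
# plus the MoE load-balancing (auxiliary) loss.
lm_loss = torch.tensor(2.30)   # language-modeling cross-entropy
moe_loss = torch.tensor(0.05)  # MoE auxiliary loss
loss = lm_loss + moe_loss      # what execute_schedule reports as `loss`

val_loss = 0.0
val_loss += loss.item() - moe_loss.item()  # as in the diff: report only the LM part
assert abs(val_loss - lm_loss.item()) < 1e-6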

@@ -186,11 +186,11 @@ def train(
             # do forward and backward
             timer("fwd-bwd").start()
-            _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
+            _, _, loss, moe_loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
             if gpc.is_rank_for_log():
                 assert loss is not None and not math.isnan(loss.item())
                 global cur_loss_list
-                cur_loss_list.append(loss.item())
+                cur_loss_list.append(loss.item() - moe_loss.item())
             timer("fwd-bwd").stop()

             # update parameters, and returns (success_update, grad_norm)
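
The same bookkeeping applies during training: the full loss, including the MoE auxiliary term, still drives the backward pass, while the logged value subtracts moe_loss so the loss curve tracks only the language-modeling objective. A minimal sketch under that assumption, with a toy model standing in for the real trainer:

import math
import torch
import torch.nn as nn

# Toy stand-in for the model; the weight penalty below plays the role of
# the MoE load-balancing (auxiliary) loss.
model = nn.Linear(8, 8)
x, y = torch.randn(4, 8), torch.randn(4, 8)

lm_loss = nn.functional.mse_loss(model(x), y)  # stand-in for the LM loss
moe_loss = 0.01 * model.weight.pow(2).mean()   # stand-in for the MoE auxiliary loss
loss = lm_loss + moe_loss                      # the `loss` the schedule reports

loss.backward()  # gradients include the auxiliary term

cur_loss_list = []
assert loss is not None and not math.isnan(loss.item())
cur_loss_list.append(loss.item() - moe_loss.item())  # log only the LM component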