Merge branch 'feature_add_moe' of https://github.com/blankde/InternLM into feature_add_moe

pull/182/head
Wenwen Qu 2023-09-22 15:57:36 +08:00
commit 1fdc7107b4
4 changed files with 6 additions and 7 deletions

1
.gitignore vendored
View File

@ -132,7 +132,6 @@ runs_bak/
LLM_ALERT LLM_ALERT
small_demo/ small_demo/
7b_llama_nopp/ 7b_llama_nopp/
test/
# Pytorch # Pytorch
*.pth *.pth

View File

@ -112,7 +112,7 @@ def evaluate_on_val_dls(
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook], metric_hook_list=[val_sche_metric_hook],
): ):
_, _, loss, _ = trainer.execute_schedule( _, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False batch, forward_only=True, return_loss=True, return_output_label=False
) )
else: else:
@ -126,11 +126,11 @@ def evaluate_on_val_dls(
grad_accum_batch_size=grad_accum_batch_size, grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook], metric_hook_list=[val_sche_metric_hook],
): ):
_, _, loss, _ = trainer.execute_schedule( _, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False batch, forward_only=True, return_loss=True, return_output_label=False
) )
if verbose: if verbose:
val_loss += loss.item() val_loss += loss.item() - moe_loss.item()
assert val_idx != -1 assert val_idx != -1
dist.barrier() dist.barrier()

View File

@ -186,11 +186,11 @@ def train(
# do forward and backward # do forward and backward
timer("fwd-bwd").start() timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False) _, _, loss, moe_loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
assert loss is not None and not math.isnan(loss.item()) assert loss is not None and not math.isnan(loss.item())
global cur_loss_list global cur_loss_list
cur_loss_list.append(loss.item()) cur_loss_list.append(loss.item() - moe_loss.item())
timer("fwd-bwd").stop() timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm) # update parameters, and returns (success_update, grad_norm)