mirror of https://github.com/InternLM/InternLM
Merge branch 'feature_add_moe' of https://github.com/blankde/InternLM into feature_add_moe
commit
1fdc7107b4
|
@ -132,7 +132,6 @@ runs_bak/
|
||||||
LLM_ALERT
|
LLM_ALERT
|
||||||
small_demo/
|
small_demo/
|
||||||
7b_llama_nopp/
|
7b_llama_nopp/
|
||||||
test/
|
|
||||||
|
|
||||||
# Pytorch
|
# Pytorch
|
||||||
*.pth
|
*.pth
|
||||||
|
|
|
@ -112,7 +112,7 @@ def evaluate_on_val_dls(
|
||||||
tensor_shape=tensor_shape,
|
tensor_shape=tensor_shape,
|
||||||
metric_hook_list=[val_sche_metric_hook],
|
metric_hook_list=[val_sche_metric_hook],
|
||||||
):
|
):
|
||||||
_, _, loss, _ = trainer.execute_schedule(
|
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -126,11 +126,11 @@ def evaluate_on_val_dls(
|
||||||
grad_accum_batch_size=grad_accum_batch_size,
|
grad_accum_batch_size=grad_accum_batch_size,
|
||||||
metric_hook_list=[val_sche_metric_hook],
|
metric_hook_list=[val_sche_metric_hook],
|
||||||
):
|
):
|
||||||
_, _, loss, _ = trainer.execute_schedule(
|
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||||
)
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
val_loss += loss.item()
|
val_loss += loss.item() - moe_loss.item()
|
||||||
|
|
||||||
assert val_idx != -1
|
assert val_idx != -1
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
|
@ -186,11 +186,11 @@ def train(
|
||||||
# do forward and backward
|
# do forward and backward
|
||||||
timer("fwd-bwd").start()
|
timer("fwd-bwd").start()
|
||||||
|
|
||||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
_, _, loss, moe_loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||||
if gpc.is_rank_for_log():
|
if gpc.is_rank_for_log():
|
||||||
assert loss is not None and not math.isnan(loss.item())
|
assert loss is not None and not math.isnan(loss.item())
|
||||||
global cur_loss_list
|
global cur_loss_list
|
||||||
cur_loss_list.append(loss.item())
|
cur_loss_list.append(loss.item() - moe_loss.item())
|
||||||
timer("fwd-bwd").stop()
|
timer("fwd-bwd").stop()
|
||||||
|
|
||||||
# update parameters, and returns (success_update, grad_norm)
|
# update parameters, and returns (success_update, grad_norm)
|
||||||
|
|
Loading…
Reference in New Issue