mirror of https://github.com/InternLM/InternLM

change condition for compatibility

parent 591b4edb1d
commit f96764868d
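The change repeated throughout this diff swaps string comparisons on model_type for a check on whether the model config actually defines MoE fields. The sketch below illustrates that idea in isolation; the SimpleNamespace configs are hypothetical stand-ins for internlm's gpc.config, which is not shown here. The point is that an old config written before the MoE fields existed may define neither model_type nor num_experts, and only the new hasattr check tolerates that.

# Hypothetical stand-ins for gpc.config; only the shape of the check matters.
from types import SimpleNamespace

dense_cfg = SimpleNamespace(model=SimpleNamespace(hidden_size=4096))
moe_cfg = SimpleNamespace(
    model=SimpleNamespace(hidden_size=4096, num_experts=8),
    model_type="INTERNLM_MoE",
)

def is_moe_old(cfg):
    # Old condition: branch on an explicit model_type string.
    return getattr(cfg, "model_type", "INTERNLM") == "INTERNLM_MoE"

def is_moe_new(cfg):
    # New condition: branch on whether the model config carries MoE fields,
    # so configs that predate model_type keep working unchanged.
    return hasattr(cfg.model, "num_experts")

print(is_moe_old(dense_cfg), is_moe_new(dense_cfg))  # False False
print(is_moe_old(moe_cfg), is_moe_new(moe_cfg))      # True True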
@@ -106,11 +106,11 @@ class NonPipelineScheduler(BaseScheduler):
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
             self._call_hooks("before_forward", data)
             # moe_losses contains the loss of each layer
-            if gpc.config.get("model_type") == "INTERNLM":
-                output = self._call_engine(engine, data)
-            if gpc.config.get("model_type") == "INTERNLM_MoE":
+            if hasattr(gpc.config.model, "num_experts"):
                 # moe is used
                 output, moe_losses = self._call_engine(engine, data)
+            else:
+                output = self._call_engine(engine, data)
             self._call_hooks("after_forward", output)

             self._call_hooks("post_helper_func", output, label)
@@ -121,7 +121,7 @@ class NonPipelineScheduler(BaseScheduler):
             self._call_hooks("after_criterion", loss)
             moe_loss = (
                 sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                if gpc.config.get("model_type") == "INTERNLM_MoE"
+                if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
             moe_loss /= scale_loss
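The moe_loss expression above can be exercised on its own. The sketch below uses plain CPU tensors and made-up moe_loss_coeff and scale_loss values in place of gpc.config and the CUDA device used in the real code; it only demonstrates the sum-or-zero-fallback pattern.

import torch

def aggregate_moe_loss(moe_losses, has_moe, moe_loss_coeff=0.01, scale_loss=1.0):
    # Sum the per-layer auxiliary losses when MoE is enabled, otherwise fall
    # back to a zero tensor so downstream code handles both cases uniformly.
    moe_loss = (
        sum(moe_losses) * moe_loss_coeff
        if has_moe
        else torch.tensor(0.0)
    )
    return moe_loss / scale_loss

print(aggregate_moe_loss([torch.tensor(0.2), torch.tensor(0.3)], has_moe=True))  # tensor(0.0050)
print(aggregate_moe_loss([], has_moe=False))                                     # tensor(0.)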
@@ -206,8 +206,8 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_output_label:
             outputs, labels = None, None

-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return outputs, labels, loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return outputs, labels, loss, moe_loss
+        else:
+            return outputs, labels, loss
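With this hunk the scheduler returns either three or four values depending on whether the config carries num_experts. A minimal, hypothetical caller that copes with both shapes might look like the following; fake_schedule_step is a stand-in, not the real trainer.execute_schedule.

def fake_schedule_step(moe_enabled):
    # Mimics the two possible return shapes of the scheduler step.
    if moe_enabled:
        return "outputs", "labels", 1.25, 0.03  # ..., loss, moe_loss
    return "outputs", "labels", 1.25            # ..., loss

def run_step(moe_enabled):
    result = fake_schedule_step(moe_enabled)
    # Unpack a trailing moe_loss only when the schedule returned one.
    if len(result) == 4:
        outputs, labels, loss, moe_loss = result
    else:
        outputs, labels, loss = result
        moe_loss = None
    return loss, moe_loss

print(run_step(moe_enabled=True))   # (1.25, 0.03)
print(run_step(moe_enabled=False))  # (1.25, None)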
@@ -275,11 +275,11 @@ class PipelineScheduler(BaseScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

         self._call_hooks("before_forward", data)
         # moe_losses contains the loss of each layer in current stage
-        if gpc.config.get("model_type") == "INTERNLM":
-            output_obj = self._call_engine(engine.model, data)
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        if hasattr(gpc.config.model, "num_experts"):
             # moe is used
             output_obj, moe_losses = self._call_engine(engine.model, data)
+        else:
+            output_obj = self._call_engine(engine.model, data)
         self._call_hooks("after_forward", output_obj)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -297,7 +297,7 @@ class PipelineScheduler(BaseScheduler):

             moe_loss = (
                 sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                if gpc.config.get("model_type") == "INTERNLM_MoE"
+                if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
             moe_loss /= self.num_microbatches
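In the pipeline schedulers the auxiliary loss of each microbatch is divided by num_microbatches before it is accumulated, so the accumulated value ends up as an average over microbatches rather than a sum. A small self-contained sketch of that accumulation, with illustrative numbers and a made-up moe_loss_coeff:

import torch

def accumulate_moe_loss(per_microbatch_layer_losses, moe_loss_coeff=0.01):
    num_microbatches = len(per_microbatch_layer_losses)
    accum_moe_loss = torch.tensor(0.0)
    for layer_losses in per_microbatch_layer_losses:
        # Each microbatch contributes the scaled sum of its per-layer losses,
        # already divided by num_microbatches as in the diff above.
        moe_loss = sum(layer_losses) * moe_loss_coeff
        moe_loss /= num_microbatches
        accum_moe_loss += moe_loss
    return accum_moe_loss

print(accumulate_moe_loss([[torch.tensor(0.2)], [torch.tensor(0.4)]]))  # tensor(0.0030)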
@@ -673,11 +673,11 @@ class PipelineScheduler(BaseScheduler):
             engine, return_loss, return_output_label
         )

-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return output, label, accum_loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return output, label, accum_loss, accum_moe_loss
+        else:
+            return output, label, accum_loss


 class InterleavedPipelineScheduler(PipelineScheduler):
@@ -816,10 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

         self._call_hooks("before_forward", data)
-        if gpc.config.get("model_type") == "INTERNLM":
-            output_obj = self._call_engine(engine.model[chunk_id], data)
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        if hasattr(gpc.config.model, "num_experts"):
             output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
+        else:
+            output_obj = self._call_engine(engine.model[chunk_id], data)
         # Convert output_obj to fp32 when last model chunk of last stage
         if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
             output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@@ -841,7 +841,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):

             moe_loss = (
                 sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                if gpc.config.get("model_type") == "INTERNLM_MoE"
+                if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
             moe_loss /= self.num_microbatches
@@ -1378,8 +1378,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):

         self._clear_state()

-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return output, label, accum_loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return output, label, accum_loss, accum_moe_loss
+        else:
+            return output, label, accum_loss
@@ -264,7 +264,7 @@ def args_sanity_check():
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)

-    if gpc.config.get("model_type") == "INTERNLM_MoE":
+    if "MoE" in gpc.config.get("model_type", "INTERNLM"):
         if "num_experts" not in model:
             model._add_item("num_experts", 1)
         if "moe_use_residual" not in model:
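The sanity check now treats any model_type containing "MoE" as a MoE model and falls back to "INTERNLM" when model_type is missing, then fills in MoE defaults. A rough equivalent using a plain dict in place of gpc.config.model (which exposes _add_item in the real code); the moe_use_residual default below is illustrative, since the diff does not show its value.

def fill_moe_defaults(model, model_type):
    # Missing model_type falls back to "INTERNLM", i.e. a dense model,
    # so old configs pass through untouched.
    if "MoE" in (model_type or "INTERNLM"):
        model.setdefault("num_experts", 1)
        model.setdefault("moe_use_residual", False)  # illustrative default
    return model

print(fill_moe_defaults({}, None))            # {}
print(fill_moe_defaults({}, "INTERNLM_MoE"))  # {'num_experts': 1, 'moe_use_residual': False}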
@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
+    if gpc.config.model.get("num_experts", 0) > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,11 +57,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # first split the norm and gate groups, which are special case to force sync (when enable MoE),
     # then fp32 group and the moe group.
     for param in pgroup["params"]:
-        if (
-            gpc.config.get("model_type") == "INTERNLM_MoE"
-            and gpc.config.model.num_experts > 1
-            and is_norm_param(param)
-        ):
+        if gpc.config.model.get("num_experts", 0) > 1 and is_norm_param(param):
             new_groups["norm"]["params"].append(param)
         # gate param means MoE is enabled
         elif is_gate_param(param):
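The optimizer grouping now keys off gpc.config.model.get("num_experts", 0) > 1 rather than the model_type string, so a dense or legacy config (no num_experts key) never creates the MoE-specific groups. A simplified sketch of that split, with name-based predicates standing in for internlm's is_norm_param / is_gate_param helpers and a plain dict standing in for the model config:

def split_param_groups(named_params, model_cfg):
    groups = {"norm": [], "gate": [], "default": []}
    moe_enabled = model_cfg.get("num_experts", 0) > 1
    for name, param in named_params:
        # norm params only get their own force-synced group when MoE is on.
        if moe_enabled and "norm" in name:
            groups["norm"].append(param)
        elif "gate" in name:  # gate params only exist when MoE is enabled
            groups["gate"].append(param)
        else:
            groups["default"].append(param)
    return groups

params = [
    ("layers.0.norm.weight", "p0"),
    ("layers.0.moe.gate.weight", "p1"),
    ("layers.0.attention.weight", "p2"),
]
print(split_param_groups(params, {"num_experts": 8}))   # norm and gate get their own groups
print(split_param_groups(params[:1] + params[2:], {}))  # dense config: everything stays in "default"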
@@ -113,13 +113,13 @@ def evaluate_on_val_dls(
                 tensor_shape=tensor_shape,
                 metric_hook_list=[val_sche_metric_hook],
             ):
-                # Compatible for old code
-                if gpc.config.get("model_type") == "INTERNLM":
-                    _, _, loss = trainer.execute_schedule(
+                # Compatible for non-moe
+                if hasattr(gpc.config.model, "num_experts"):
+                    _, _, loss, moe_loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
-                elif gpc.config.get("model_type") == "INTERNLM_MoE":
-                    _, _, loss, moe_loss = trainer.execute_schedule(
+                else:
+                    _, _, loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
         else:
@@ -133,12 +133,12 @@ def evaluate_on_val_dls(
                 grad_accum_batch_size=grad_accum_batch_size,
                 metric_hook_list=[val_sche_metric_hook],
             ):
-                if gpc.config.get("model_type") == "INTERNLM":
-                    _, _, loss = trainer.execute_schedule(
+                if hasattr(gpc.config.model, "num_experts"):
+                    _, _, loss, moe_loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
-                elif gpc.config.get("model_type") == "INTERNLM_MoE":
-                    _, _, loss, moe_loss = trainer.execute_schedule(
+                else:
+                    _, _, loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
         if verbose:
@@ -186,14 +186,14 @@ def train(
         # do forward and backward
         timer("fwd-bwd").start()

-        # Compatible for old code
+        # Compatible for non-moe
         moe_loss = None
-        if gpc.config.get("model_type") == "INTERNLM":
-            _, _, loss = trainer.execute_schedule(
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
                 batch, forward_only=False, return_loss=True, return_output_label=False
             )
-        elif gpc.config.get("model_type") == "INTERNLM_MoE":
-            _, _, loss, moe_loss = trainer.execute_schedule(
+        else:
+            _, _, loss = trainer.execute_schedule(
                 batch, forward_only=False, return_loss=True, return_output_label=False
             )
         if gpc.is_rank_for_log():
train.py
@@ -220,15 +220,15 @@ def main(args):
         timer("fwd-bwd").start()

         moe_loss = None
-        if gpc.config.get("model_type") == "INTERNLM":
-            _, _, loss = trainer.execute_schedule(
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
                 batch,
                 forward_only=False,
                 return_loss=True,
                 return_output_label=False,
             )
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
-            _, _, loss, moe_loss = trainer.execute_schedule(
+        else:
+            _, _, loss = trainer.execute_schedule(
                 batch,
                 forward_only=False,
                 return_loss=True,