change condition for compatibility

pull/182/head
Wenwen Qu 2023-09-27 12:36:03 +08:00
parent 591b4edb1d
commit f96764868d
7 changed files with 47 additions and 51 deletions

View File

@ -106,11 +106,11 @@ class NonPipelineScheduler(BaseScheduler):
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
# moe_losses contains the loss of each layer
if gpc.config.get("model_type") == "INTERNLM":
output = self._call_engine(engine, data)
if gpc.config.get("model_type") == "INTERNLM_MoE":
if hasattr(gpc.config.model, "num_experts"):
# moe is used
output, moe_losses = self._call_engine(engine, data)
else:
output = self._call_engine(engine, data)
self._call_hooks("after_forward", output)
self._call_hooks("post_helper_func", output, label)
@ -121,7 +121,7 @@ class NonPipelineScheduler(BaseScheduler):
self._call_hooks("after_criterion", loss)
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
if gpc.config.get("model_type") == "INTERNLM_MoE"
if hasattr(gpc.config.model, "num_experts")
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
moe_loss /= scale_loss
@ -206,8 +206,8 @@ class NonPipelineScheduler(BaseScheduler):
if not return_output_label:
outputs, labels = None, None
# Compatible for old code
if gpc.config.get("model_type") == "INTERNLM":
return outputs, labels, loss
if gpc.config.get("model_type") == "INTERNLM_MoE":
# Compatible for non-moe
if hasattr(gpc.config.model, "num_experts"):
return outputs, labels, loss, moe_loss
else:
return outputs, labels, loss

View File

@ -275,11 +275,11 @@ class PipelineScheduler(BaseScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
# moe_losses contains the loss of each layer in current stage
if gpc.config.get("model_type") == "INTERNLM":
output_obj = self._call_engine(engine.model, data)
if gpc.config.get("model_type") == "INTERNLM_MoE":
if hasattr(gpc.config.model, "num_experts"):
# moe is used
output_obj, moe_losses = self._call_engine(engine.model, data)
else:
output_obj = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)
if gpc.is_last_rank(ParallelMode.PIPELINE):
@ -297,7 +297,7 @@ class PipelineScheduler(BaseScheduler):
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
if gpc.config.get("model_type") == "INTERNLM_MoE"
if hasattr(gpc.config.model, "num_experts")
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
moe_loss /= self.num_microbatches
@ -673,11 +673,11 @@ class PipelineScheduler(BaseScheduler):
engine, return_loss, return_output_label
)
# Compatible for old code
if gpc.config.get("model_type") == "INTERNLM":
return output, label, accum_loss
if gpc.config.get("model_type") == "INTERNLM_MoE":
# Compatible for non-moe
if hasattr(gpc.config.model, "num_experts"):
return output, label, accum_loss, accum_moe_loss
else:
return output, label, accum_loss
class InterleavedPipelineScheduler(PipelineScheduler):
@ -816,10 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
if gpc.config.get("model_type") == "INTERNLM":
output_obj = self._call_engine(engine.model[chunk_id], data)
if gpc.config.get("model_type") == "INTERNLM_MoE":
if hasattr(gpc.config.model, "num_experts"):
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
else:
output_obj = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@ -841,7 +841,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
if gpc.config.get("model_type") == "INTERNLM_MoE"
if hasattr(gpc.config.model, "num_experts")
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
moe_loss /= self.num_microbatches
@ -1378,8 +1378,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._clear_state()
# Compatible for old code
if gpc.config.get("model_type") == "INTERNLM":
return output, label, accum_loss
if gpc.config.get("model_type") == "INTERNLM_MoE":
# Compatible for non-moe
if hasattr(gpc.config.model, "num_experts"):
return output, label, accum_loss, accum_moe_loss
else:
return output, label, accum_loss

View File

@ -264,7 +264,7 @@ def args_sanity_check():
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
if gpc.config.get("model_type") == "INTERNLM_MoE":
if "MoE" in gpc.config.get("model_type", "INTERNLM"):
if "num_experts" not in model:
model._add_item("num_experts", 1)
if "moe_use_residual" not in model:

View File

@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
# create new groups for fp32, norm, moe gate and moe expert
new_groups = {}
new_groups["fp32"] = {"name": "fp32", "params": []}
if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
if gpc.config.model.get("num_experts", 0) > 1:
# norm and gate are special group to force sync (when enable MoE).
for key in ["gate", "norm"]:
new_groups[key] = {"name": key, key: True, "params": []}
@ -57,11 +57,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
# first split the norm and gate groups, which are special case to force sync (when enable MoE),
# then fp32 group and the moe group.
for param in pgroup["params"]:
if (
gpc.config.get("model_type") == "INTERNLM_MoE"
and gpc.config.model.num_experts > 1
and is_norm_param(param)
):
if gpc.config.model.get("num_experts", 0) > 1 and is_norm_param(param):
new_groups["norm"]["params"].append(param)
# gate param means MoE is enabled
elif is_gate_param(param):

View File

@ -113,13 +113,13 @@ def evaluate_on_val_dls(
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
# Compatible for old code
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
# Compatible for non-moe
if hasattr(gpc.config.model, "num_experts"):
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
elif gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
else:
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
@ -133,12 +133,12 @@ def evaluate_on_val_dls(
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
if hasattr(gpc.config.model, "num_experts"):
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
elif gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
else:
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:

View File

@ -186,14 +186,14 @@ def train(
# do forward and backward
timer("fwd-bwd").start()
# Compatible for old code
# Compatible for non-moe
moe_loss = None
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
if hasattr(gpc.config.model, "num_experts"):
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False
)
elif gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
else:
_, _, loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False
)
if gpc.is_rank_for_log():

View File

@ -220,15 +220,15 @@ def main(args):
timer("fwd-bwd").start()
moe_loss = None
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
if hasattr(gpc.config.model, "num_experts"):
_, _, loss, moe_loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
if gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
else:
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,