From f96764868da58324cd505e8e2a4cd51fb6744fcb Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Wed, 27 Sep 2023 12:36:03 +0800
Subject: [PATCH] change condition for compatibility

---
 .../core/scheduler/no_pipeline_scheduler.py   | 18 +++++-----
 internlm/core/scheduler/pipeline_scheduler.py | 34 +++++++++----------
 internlm/initialize/launch.py                 |  2 +-
 internlm/train/utils.py                       |  8 ++---
 internlm/utils/evaluation.py                  | 18 +++++-----
 tests/test_training/test_loss.py              | 10 +++---
 train.py                                      |  8 ++---
 7 files changed, 47 insertions(+), 51 deletions(-)

diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 3de3d45..56661d8 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -106,11 +106,11 @@ class NonPipelineScheduler(BaseScheduler):
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
             self._call_hooks("before_forward", data)
-            # moe_losses contains the loss of each layer
-            if gpc.config.get("model_type") == "INTERNLM":
-                output = self._call_engine(engine, data)
-            if gpc.config.get("model_type") == "INTERNLM_MoE":
+            if hasattr(gpc.config.model, "num_experts"):
+                # moe is used
                 output, moe_losses = self._call_engine(engine, data)
+            else:
+                output = self._call_engine(engine, data)
             self._call_hooks("after_forward", output)
 
             self._call_hooks("post_helper_func", output, label)
@@ -121,7 +121,7 @@ class NonPipelineScheduler(BaseScheduler):
                 self._call_hooks("after_criterion", loss)
                 moe_loss = (
                     sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                    if gpc.config.get("model_type") == "INTERNLM_MoE"
+                    if hasattr(gpc.config.model, "num_experts")
                     else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
                 )
                 moe_loss /= scale_loss
@@ -206,8 +206,8 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_output_label:
             outputs, labels = None, None
 
-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return outputs, labels, loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return outputs, labels, loss, moe_loss
+        else:
+            return outputs, labels, loss
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 42e58e9..4ef8c86 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -275,11 +275,11 @@ class PipelineScheduler(BaseScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
 
         self._call_hooks("before_forward", data)
-        # moe_losses contains the loss of each layer in current stage
-        if gpc.config.get("model_type") == "INTERNLM":
-            output_obj = self._call_engine(engine.model, data)
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        if hasattr(gpc.config.model, "num_experts"):
+            # moe is used
             output_obj, moe_losses = self._call_engine(engine.model, data)
+        else:
+            output_obj = self._call_engine(engine.model, data)
         self._call_hooks("after_forward", output_obj)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -297,7 +297,7 @@ class PipelineScheduler(BaseScheduler):
 
                 moe_loss = (
                     sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                    if gpc.config.get("model_type") == "INTERNLM_MoE"
+                    if hasattr(gpc.config.model, "num_experts")
                     else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
                 )
                 moe_loss /= self.num_microbatches
@@ -673,11 +673,11 @@ class PipelineScheduler(BaseScheduler):
                 engine, return_loss, return_output_label
             )
 
-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return output, label, accum_loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return output, label, accum_loss, accum_moe_loss
+        else:
+            return output, label, accum_loss
 
 
 class InterleavedPipelineScheduler(PipelineScheduler):
@@ -816,10 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
 
         self._call_hooks("before_forward", data)
-        if gpc.config.get("model_type") == "INTERNLM":
-            output_obj = self._call_engine(engine.model[chunk_id], data)
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        if hasattr(gpc.config.model, "num_experts"):
             output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
+        else:
+            output_obj = self._call_engine(engine.model[chunk_id], data)
 
         # Convert output_obj to fp32 when last model chunk of last stage
         if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
             output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@@ -841,7 +841,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 
                 moe_loss = (
                     sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                    if gpc.config.get("model_type") == "INTERNLM_MoE"
+                    if hasattr(gpc.config.model, "num_experts")
                     else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
                 )
                 moe_loss /= self.num_microbatches
@@ -1378,8 +1378,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 
         self._clear_state()
 
-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return output, label, accum_loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return output, label, accum_loss, accum_moe_loss
+        else:
+            return output, label, accum_loss
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 2da6afd..660cc55 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -264,7 +264,7 @@ def args_sanity_check():
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
 
-    if gpc.config.get("model_type") == "INTERNLM_MoE":
+    if "MoE" in gpc.config.get("model_type", "INTERNLM"):
         if "num_experts" not in model:
             model._add_item("num_experts", 1)
         if "moe_use_residual" not in model:
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
index 2f4aa67..a05a6b2 100644
--- a/internlm/train/utils.py
+++ b/internlm/train/utils.py
@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
+    if gpc.config.model.get("num_experts", 0) > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,11 +57,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # first split the norm and gate groups, which are special case to force sync (when enable MoE),
     # then fp32 group and the moe group.
     for param in pgroup["params"]:
-        if (
-            gpc.config.get("model_type") == "INTERNLM_MoE"
-            and gpc.config.model.num_experts > 1
-            and is_norm_param(param)
-        ):
+        if gpc.config.model.get("num_experts", 0) > 1 and is_norm_param(param):
             new_groups["norm"]["params"].append(param)
         # gate param means MoE is enabled
         elif is_gate_param(param):
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 13d4468..6a55fa5 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -113,13 +113,13 @@ def evaluate_on_val_dls(
                         tensor_shape=tensor_shape,
                         metric_hook_list=[val_sche_metric_hook],
                     ):
-                        # Compatible for old code
-                        if gpc.config.get("model_type") == "INTERNLM":
-                            _, _, loss = trainer.execute_schedule(
+                        # Compatible for non-moe
+                        if hasattr(gpc.config.model, "num_experts"):
+                            _, _, loss, moe_loss = trainer.execute_schedule(
                                 batch, forward_only=True, return_loss=True, return_output_label=False
                             )
-                        elif gpc.config.get("model_type") == "INTERNLM_MoE":
-                            _, _, loss, moe_loss = trainer.execute_schedule(
+                        else:
+                            _, _, loss = trainer.execute_schedule(
                                 batch, forward_only=True, return_loss=True, return_output_label=False
                             )
                 else:
@@ -133,12 +133,12 @@ def evaluate_on_val_dls(
                         grad_accum_batch_size=grad_accum_batch_size,
                         metric_hook_list=[val_sche_metric_hook],
                     ):
-                        if gpc.config.get("model_type") == "INTERNLM":
-                            _, _, loss = trainer.execute_schedule(
+                        if hasattr(gpc.config.model, "num_experts"):
+                            _, _, loss, moe_loss = trainer.execute_schedule(
                                 batch, forward_only=True, return_loss=True, return_output_label=False
                             )
-                        elif gpc.config.get("model_type") == "INTERNLM_MoE":
-                            _, _, loss, moe_loss = trainer.execute_schedule(
+                        else:
+                            _, _, loss = trainer.execute_schedule(
                                 batch, forward_only=True, return_loss=True, return_output_label=False
                             )
         if verbose:
diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py
index caba6a9..a30cfba 100644
--- a/tests/test_training/test_loss.py
+++ b/tests/test_training/test_loss.py
@@ -186,14 +186,14 @@ def train(
         # do forward and backward
         timer("fwd-bwd").start()
 
-        # Compatible for old code
+        # Compatible for non-moe
         moe_loss = None
-        if gpc.config.get("model_type") == "INTERNLM":
-            _, _, loss = trainer.execute_schedule(
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
                 batch, forward_only=False, return_loss=True, return_output_label=False
             )
-        elif gpc.config.get("model_type") == "INTERNLM_MoE":
-            _, _, loss, moe_loss = trainer.execute_schedule(
+        else:
+            _, _, loss = trainer.execute_schedule(
                 batch, forward_only=False, return_loss=True, return_output_label=False
             )
         if gpc.is_rank_for_log():
diff --git a/train.py b/train.py
index 7a30f6c..139bac1 100644
--- a/train.py
+++ b/train.py
@@ -220,15 +220,15 @@ def main(args):
         timer("fwd-bwd").start()
 
         moe_loss = None
-        if gpc.config.get("model_type") == "INTERNLM":
-            _, _, loss = trainer.execute_schedule(
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
                 batch,
                 forward_only=False,
                 return_loss=True,
                 return_output_label=False,
             )
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
-            _, _, loss, moe_loss = trainer.execute_schedule(
+        else:
+            _, _, loss = trainer.execute_schedule(
                 batch,
                 forward_only=False,
                 return_loss=True,
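Note on the condition change (illustrative, not part of the patch): InternLM's `gpc.config` is an attribute-style dict, so a MoE run defines `model.num_experts` while a dense run leaves it unset. That is why `hasattr(gpc.config.model, "num_experts")` can replace the `model_type == "INTERNLM_MoE"` string comparison, and why the optimizer grouping uses `gpc.config.model.get("num_experts", 0) > 1` to treat a single-expert config as dense. A minimal sketch of the pattern, assuming only that the config behaves like the toy attribute-dict below (this `Config` class is a stand-in for illustration, not the library's actual implementation):

    # Toy stand-in for InternLM's attribute-style config; the real class
    # lives elsewhere in the repo. Keys are readable as attributes.
    class Config(dict):
        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError as exc:
                raise AttributeError(name) from exc

    dense_model = Config()                 # dense config: no num_experts key
    moe_model = Config(num_experts=8)      # MoE config: num_experts is set

    # New scheduler-side check: ask the config itself whether MoE is on,
    # instead of comparing against one hard-coded model_type string.
    assert not hasattr(dense_model, "num_experts")
    assert hasattr(moe_model, "num_experts")

    # Optimizer-grouping variant: num_experts == 1 still counts as dense.
    assert not (Config(num_experts=1).get("num_experts", 0) > 1)
    assert moe_model.get("num_experts", 0) > 1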