mirror of https://github.com/InternLM/InternLM

change condition for compatibility
parent 591b4edb1d
commit f96764868d
@@ -106,11 +106,11 @@ class NonPipelineScheduler(BaseScheduler):
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
             self._call_hooks("before_forward", data)
-            # moe_losses contains the loss of each layer
-            if gpc.config.get("model_type") == "INTERNLM":
-                output = self._call_engine(engine, data)
-            if gpc.config.get("model_type") == "INTERNLM_MoE":
+            if hasattr(gpc.config.model, "num_experts"):
+                # moe is used
                 output, moe_losses = self._call_engine(engine, data)
+            else:
+                output = self._call_engine(engine, data)
             self._call_hooks("after_forward", output)

             self._call_hooks("post_helper_func", output, label)
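The recurring change in this commit replaces exact model_type string matching with feature detection: a model takes the MoE path whenever its config defines num_experts. A minimal runnable sketch of why the hasattr check works; the Config class below is a hypothetical stand-in for the attribute-style config object behind gpc.config.model, not the repo's actual class:

class Config(dict):
    """Dict with attribute-style access, mimicking how gpc.config.model behaves."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


dense = Config(dtype="bfloat16")
moe = Config(dtype="bfloat16", num_experts=8)

# hasattr() returns False when __getattr__ raises AttributeError, so the new
# condition routes any config declaring num_experts down the MoE path,
# regardless of which model_type string it carries.
print(hasattr(dense, "num_experts"), hasattr(moe, "num_experts"))  # False True

The if/else structure also guarantees that output is always bound; with the old pair of independent if statements, an unrecognized model_type fell through both branches.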
@@ -121,7 +121,7 @@ class NonPipelineScheduler(BaseScheduler):
                 self._call_hooks("after_criterion", loss)
                 moe_loss = (
                     sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                    if gpc.config.get("model_type") == "INTERNLM_MoE"
+                    if hasattr(gpc.config.model, "num_experts")
                     else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
                 )
                 moe_loss /= scale_loss
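The conditional expression folds the per-layer auxiliary losses into one scalar on the MoE path, or a zero tensor otherwise. A CPU-safe sketch with assumed values; the real code places the zero tensor on torch.cuda.current_device() and reads the coefficient from gpc.config.loss.moe_loss_coeff:

import torch

moe_losses = [torch.tensor(0.15), torch.tensor(0.25)]  # one entry per MoE layer
moe_loss_coeff = 0.01  # assumed value for illustration
scale_loss = 4         # e.g. gradient accumulation size

moe_loss = sum(moe_losses) * moe_loss_coeff  # tensor(0.0040)
moe_loss /= scale_loss                       # tensor(0.0010), matches main-loss scaling
print(moe_loss)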
@@ -206,8 +206,8 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_output_label:
             outputs, labels = None, None

-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return outputs, labels, loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return outputs, labels, loss, moe_loss
+        else:
+            return outputs, labels, loss
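forward_backward_step now returns a 3-tuple for dense models and a 4-tuple for MoE, and every call site touched by this commit branches on the same condition. A hedged caller-side sketch that normalizes both shapes; unpack_schedule_result is a hypothetical helper, not part of this diff:

def unpack_schedule_result(result):
    """Normalize (outputs, labels, loss[, moe_loss]) to a fixed 4-tuple."""
    if len(result) == 4:
        return result
    outputs, labels, loss = result
    return outputs, labels, loss, None  # dense models carry no moe_loss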
@@ -275,11 +275,11 @@ class PipelineScheduler(BaseScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

         self._call_hooks("before_forward", data)
-        # moe_losses contains the loss of each layer in current stage
-        if gpc.config.get("model_type") == "INTERNLM":
-            output_obj = self._call_engine(engine.model, data)
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        if hasattr(gpc.config.model, "num_experts"):
+            # moe is used
             output_obj, moe_losses = self._call_engine(engine.model, data)
+        else:
+            output_obj = self._call_engine(engine.model, data)
         self._call_hooks("after_forward", output_obj)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -297,7 +297,7 @@ class PipelineScheduler(BaseScheduler):

             moe_loss = (
                 sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                if gpc.config.get("model_type") == "INTERNLM_MoE"
+                if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
             moe_loss /= self.num_microbatches
@@ -673,11 +673,11 @@ class PipelineScheduler(BaseScheduler):
             engine, return_loss, return_output_label
         )

-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return output, label, accum_loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return output, label, accum_loss, accum_moe_loss
+        else:
+            return output, label, accum_loss


 class InterleavedPipelineScheduler(PipelineScheduler):
@@ -816,10 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

         self._call_hooks("before_forward", data)
-        if gpc.config.get("model_type") == "INTERNLM":
-            output_obj = self._call_engine(engine.model[chunk_id], data)
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        if hasattr(gpc.config.model, "num_experts"):
             output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
+        else:
+            output_obj = self._call_engine(engine.model[chunk_id], data)
         # Convert output_obj to fp32 when last model chunk of last stage
         if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
             output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@@ -841,7 +841,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):

             moe_loss = (
                 sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                if gpc.config.get("model_type") == "INTERNLM_MoE"
+                if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
             moe_loss /= self.num_microbatches
@@ -1378,8 +1378,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):

         self._clear_state()

-        # Compatible for old code
-        if gpc.config.get("model_type") == "INTERNLM":
-            return output, label, accum_loss
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
+        # Compatible for non-moe
+        if hasattr(gpc.config.model, "num_experts"):
             return output, label, accum_loss, accum_moe_loss
+        else:
+            return output, label, accum_loss
@@ -264,7 +264,7 @@ def args_sanity_check():
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)

-    if gpc.config.get("model_type") == "INTERNLM_MoE":
+    if "MoE" in gpc.config.get("model_type", "INTERNLM"):
         if "num_experts" not in model:
             model._add_item("num_experts", 1)
         if "moe_use_residual" not in model:
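The sanity check keeps a string test here but loosens it to a substring match with a default, so any MoE-flavored model_type gets the expert defaults filled in. A small sketch of how the condition evaluates, with assumed model_type values:

for cfg in ({"model_type": "INTERNLM"}, {"model_type": "INTERNLM_MoE"}, {}):
    # A missing model_type falls back to "INTERNLM", so dense stays the default.
    print(cfg, "->", "MoE" in cfg.get("model_type", "INTERNLM"))
# {'model_type': 'INTERNLM'} -> False
# {'model_type': 'INTERNLM_MoE'} -> True
# {} -> False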
@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
+    if gpc.config.model.get("num_experts", 0) > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,11 +57,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # first split the norm and gate groups, which are special case to force sync (when enable MoE),
         # then fp32 group and the moe group.
         for param in pgroup["params"]:
-            if (
-                gpc.config.get("model_type") == "INTERNLM_MoE"
-                and gpc.config.model.num_experts > 1
-                and is_norm_param(param)
-            ):
+            if gpc.config.model.get("num_experts", 0) > 1 and is_norm_param(param):
                 new_groups["norm"]["params"].append(param)
             # gate param means MoE is enabled
             elif is_gate_param(param):
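Collapsing the compound condition to gpc.config.model.get("num_experts", 0) > 1 also makes the check safe when num_experts is absent, since .get supplies the 0 default. A self-contained sketch of the routing; all names are hypothetical, and the repo's real predicates inspect parameter attributes rather than strings:

def split_params(params, num_experts, is_norm_param, is_gate_param):
    """Route each parameter into the norm, gate, or default optimizer group."""
    groups = {"default": [], "norm": [], "gate": []}
    for p in params:
        if num_experts > 1 and is_norm_param(p):
            groups["norm"].append(p)   # force-synced only when MoE is enabled
        elif is_gate_param(p):
            groups["gate"].append(p)   # a gate param implies MoE is enabled
        else:
            groups["default"].append(p)
    return groups


print(split_params(["w_norm", "w_gate", "w_ffn"], 8,
                   lambda p: "norm" in p, lambda p: "gate" in p))
# {'default': ['w_ffn'], 'norm': ['w_norm'], 'gate': ['w_gate']}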
@@ -113,13 +113,13 @@ def evaluate_on_val_dls(
                 tensor_shape=tensor_shape,
                 metric_hook_list=[val_sche_metric_hook],
             ):
-                # Compatible for old code
-                if gpc.config.get("model_type") == "INTERNLM":
-                    _, _, loss = trainer.execute_schedule(
+                # Compatible for non-moe
+                if hasattr(gpc.config.model, "num_experts"):
+                    _, _, loss, moe_loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
-                elif gpc.config.get("model_type") == "INTERNLM_MoE":
-                    _, _, loss, moe_loss = trainer.execute_schedule(
+                else:
+                    _, _, loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
         else:
@@ -133,12 +133,12 @@ def evaluate_on_val_dls(
                 grad_accum_batch_size=grad_accum_batch_size,
                 metric_hook_list=[val_sche_metric_hook],
             ):
-                if gpc.config.get("model_type") == "INTERNLM":
-                    _, _, loss = trainer.execute_schedule(
+                if hasattr(gpc.config.model, "num_experts"):
+                    _, _, loss, moe_loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
-                elif gpc.config.get("model_type") == "INTERNLM_MoE":
-                    _, _, loss, moe_loss = trainer.execute_schedule(
+                else:
+                    _, _, loss = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
         if verbose:
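In the training loops below, moe_loss is pre-initialized to None so only the MoE branch assigns it, letting downstream consumers guard on None. A tiny illustration with assumed values, not code from this diff:

loss, moe_loss = 2.31, None  # dense path: execute_schedule returned a 3-tuple
log_items = {"loss": loss}
if moe_loss is not None:
    log_items["moe_loss"] = moe_loss
print(log_items)  # {'loss': 2.31}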
@@ -186,14 +186,14 @@ def train(
         # do forward and backward
         timer("fwd-bwd").start()

-        # Compatible for old code
+        # Compatible for non-moe
         moe_loss = None
-        if gpc.config.get("model_type") == "INTERNLM":
-            _, _, loss = trainer.execute_schedule(
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
                 batch, forward_only=False, return_loss=True, return_output_label=False
             )
-        elif gpc.config.get("model_type") == "INTERNLM_MoE":
-            _, _, loss, moe_loss = trainer.execute_schedule(
+        else:
+            _, _, loss = trainer.execute_schedule(
                 batch, forward_only=False, return_loss=True, return_output_label=False
             )
         if gpc.is_rank_for_log():
train.py (+8, -8)
@@ -220,15 +220,15 @@ def main(args):
         timer("fwd-bwd").start()

         moe_loss = None
-        if gpc.config.get("model_type") == "INTERNLM":
-            _, _, loss = trainer.execute_schedule(
+        if hasattr(gpc.config.model, "num_experts"):
+            _, _, loss, moe_loss = trainer.execute_schedule(
                 batch,
                 forward_only=False,
                 return_loss=True,
                 return_output_label=False,
             )
-        if gpc.config.get("model_type") == "INTERNLM_MoE":
-            _, _, loss, moe_loss = trainer.execute_schedule(
+        else:
+            _, _, loss = trainer.execute_schedule(
                 batch,
                 forward_only=False,
                 return_loss=True,