mirror of https://github.com/InternLM/InternLM
remove moe_loss_coeff parameter passing
parent
e498f9262e
commit
2ad5f512b5
|
@ -7,6 +7,7 @@ from typing import Any, Callable, Iterable, List, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
|
||||
|
@ -88,7 +89,6 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
scale_loss: int = 1,
|
||||
moe_loss_coeff: float = 0.01,
|
||||
):
|
||||
"""Trains one batch of data.
|
||||
|
||||
|
@ -115,7 +115,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
self._call_hooks("before_criterion", output, label)
|
||||
loss = self._call_engine_criterion(engine, output, label)
|
||||
self._call_hooks("after_criterion", loss)
|
||||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
moe_loss /= scale_loss
|
||||
loss /= scale_loss
|
||||
loss += moe_loss
|
||||
|
@ -138,7 +138,6 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
moe_loss_coeff: float = 0.01,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
@ -184,7 +183,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
_output, _loss, _moe_loss = self._train_one_batch(
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
|
|
|
@ -256,7 +256,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_output_label=True,
|
||||
accum_loss=None,
|
||||
accum_moe_loss=None,
|
||||
moe_loss_coeff=1.0,
|
||||
):
|
||||
"""
|
||||
Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
|
@ -295,7 +294,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
accum_loss.add_(loss_reduced.detach())
|
||||
output_obj = loss_reduced
|
||||
|
||||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
moe_loss /= self.num_microbatches
|
||||
accum_moe_loss.add_(moe_loss.detach())
|
||||
|
||||
|
@ -364,7 +363,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return input_obj_grad
|
||||
|
||||
def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
|
||||
def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
|
||||
"""
|
||||
This function performs forward only computation process. The scheduling of microbatches is similar to the
|
||||
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
|
||||
|
@ -419,7 +418,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
accum_moe_loss=accum_moe_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@ -437,7 +435,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
||||
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
|
||||
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
|
||||
"""
|
||||
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
|
||||
It consists of three stages: warmup, 1F1B, and cooldown.
|
||||
|
@ -519,7 +517,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
accum_moe_loss=accum_moe_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@ -540,7 +537,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
moe_losses.append(moe_loss)
|
||||
|
||||
# Before running 1F1B, need to receive first forward tensor.
|
||||
# If all microbatches are run in warmup / cooldown phase, then no need to
|
||||
# receive this tensor here.
|
||||
|
@ -566,7 +562,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
accum_moe_loss=accum_moe_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@ -632,8 +627,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
|
||||
|
||||
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
|
||||
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
|
||||
|
@ -642,9 +635,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
||||
def forward_backward_step(
|
||||
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
|
||||
):
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||
|
||||
|
@ -667,9 +658,9 @@ class PipelineScheduler(BaseScheduler):
|
|||
self.load_batch(engine, data_iter)
|
||||
|
||||
if forward_only:
|
||||
return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
|
||||
return self._forward_only_step(engine, return_loss, return_output_label)
|
||||
else:
|
||||
return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)
|
||||
return self._forward_backward_step(engine, return_loss, return_output_label)
|
||||
|
||||
|
||||
class InterleavedPipelineScheduler(PipelineScheduler):
|
||||
|
@ -786,7 +777,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return move_to_device(micro_batch_data)
|
||||
|
||||
def _forward_step(self, engine, chunk_id, moe_loss_coeff=1.0):
|
||||
def _forward_step(self, engine, chunk_id):
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
Returns output tensor. This is a helper function and can be ignored by users.
|
||||
|
@ -828,7 +819,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
self._accum_loss.add_(loss_reduced.detach())
|
||||
output_obj = loss_reduced
|
||||
|
||||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
moe_loss /= self.num_microbatches
|
||||
|
||||
if self._accum_moe_loss is not None:
|
||||
|
@ -895,7 +886,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
num_warmup_microsteps: int,
|
||||
receive_extra_backward: bool = False,
|
||||
forward_only: bool = False,
|
||||
moe_loss_coeff: float = 1.0,
|
||||
) -> None:
|
||||
"""
|
||||
Run the warm-up loop and prepare data for the 1F1B stage.
|
||||
|
@ -933,7 +923,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
for k in range(num_warmup_microsteps):
|
||||
chunk_id = self._get_chunk_by_microbatch(k)
|
||||
|
||||
output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff)
|
||||
output_obj = self._forward_step(engine, chunk_id)
|
||||
|
||||
if forward_only:
|
||||
# when forward-only, no need to save tensors for a backward pass
|
||||
|
@ -1015,7 +1005,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
num_warmup_microsteps: int,
|
||||
num_1f1b_micropairs: int,
|
||||
all_warmup_microsteps: bool = False,
|
||||
moe_loss_coeff: float = 1.0,
|
||||
) -> None:
|
||||
"""
|
||||
Run the 1F1B loop with overlap.
|
||||
|
@ -1045,7 +1034,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
|
||||
|
||||
# 1. Forward pass.
|
||||
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
|
||||
output_obj = self._forward_step(engine, forward_chunk_id)
|
||||
|
||||
# 2. Check if the backward input is ready.
|
||||
if backward_async_communicator is not None:
|
||||
|
@ -1130,7 +1119,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
num_warmup_microsteps: int,
|
||||
num_1f1b_micropairs: int,
|
||||
all_warmup_microsteps: bool = False,
|
||||
moe_loss_coeff: float = 1.0,
|
||||
) -> None:
|
||||
"""
|
||||
Run the 1F1B loop without overlap.
|
||||
|
@ -1152,7 +1140,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
# Forward pass.
|
||||
forward_microstep_id = k + num_warmup_microsteps
|
||||
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
|
||||
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
|
||||
output_obj = self._forward_step(engine, forward_chunk_id)
|
||||
|
||||
# Backward pass.
|
||||
backward_microstep_id = k
|
||||
|
@ -1257,7 +1245,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
)
|
||||
)
|
||||
|
||||
def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
|
||||
def _forward_only_step(self, engine: Engine):
|
||||
num_microsteps = self.num_microbatches * self._num_chunks
|
||||
num_warmup_microsteps = num_microsteps
|
||||
|
||||
|
@ -1267,10 +1255,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
num_warmup_microsteps,
|
||||
receive_extra_backward=False,
|
||||
forward_only=True,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
|
||||
def _forward_backward_step(self, engine: Engine):
|
||||
# Compute number of warmup and remaining microbatches.
|
||||
all_warmup_microsteps = False
|
||||
num_microsteps = self.num_microbatches * self._num_chunks
|
||||
|
@ -1304,7 +1291,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
num_microsteps,
|
||||
num_warmup_steps,
|
||||
receive_extra_backward=receive_extra_backward,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
# 2. 1F1B
|
||||
|
@ -1313,15 +1299,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
num_warmup_steps,
|
||||
num_1f1b_micropairs=num_1f1b_micropairs,
|
||||
all_warmup_microsteps=all_warmup_microsteps,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
# 3. Cooldown
|
||||
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
|
||||
|
||||
def forward_backward_step(
|
||||
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
|
||||
):
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
||||
communication between pipeline stages as needed.
|
||||
|
||||
|
@ -1353,9 +1336,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
self._return_tensors = []
|
||||
|
||||
if forward_only:
|
||||
self._forward_only_step(engine, moe_loss_coeff)
|
||||
self._forward_only_step(engine)
|
||||
else:
|
||||
self._forward_backward_step(engine, moe_loss_coeff)
|
||||
self._forward_backward_step(engine)
|
||||
|
||||
if return_output_label and len(self._return_tensors) > 0:
|
||||
output, label = pack_return_tensors(self._return_tensors)
|
||||
|
|
|
@ -269,7 +269,7 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
|||
if "use_flash_attn" not in gpc.config.model:
|
||||
gpc.config.model._add_item("use_flash_attn", True)
|
||||
if "num_experts" not in model:
|
||||
model._add_item("num_experts", 0)
|
||||
model._add_item("num_experts", 1)
|
||||
|
||||
# process the parallel config
|
||||
if "sequence_parallel" not in gpc.config.parallel:
|
||||
|
|
|
@ -133,7 +133,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
self.moe_use_rts = moe_use_rts
|
||||
self.moe_use_residual = moe_use_residual
|
||||
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
|
||||
if num_experts == 0: # dense, not MoE
|
||||
if num_experts <= 1: # dense, not MoE
|
||||
if use_swiglu:
|
||||
self.mlp = FeedForward(
|
||||
hidden_size,
|
||||
|
|
|
@ -100,7 +100,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
|
||||
adam_cfg = gpc.config.adam
|
||||
# split the moe parameters into different groups
|
||||
if gpc.config.model.num_experts != 0:
|
||||
if gpc.config.model.num_experts > 1:
|
||||
params = create_moe_param_groups(model, adam_cfg.weight_decay)
|
||||
else:
|
||||
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
|
||||
|
@ -368,6 +368,12 @@ def record_current_batch_training_metrics(
|
|||
|
||||
tflops = get_tflops_func((time.time() - start_time))
|
||||
|
||||
# change grad_norm list to dict for calling writer's add_scalars
|
||||
grad_norm_dict = {}
|
||||
assert isinstance(grad_norm, list)
|
||||
for inx, norm in enumerate(grad_norm):
|
||||
grad_norm_dict[f"grad_norm_{inx}"] = norm
|
||||
|
||||
infos = {
|
||||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
|
@ -376,7 +382,7 @@ def record_current_batch_training_metrics(
|
|||
"tgs (tokens/gpu/second)": tk_per_gpu,
|
||||
"lr": lr,
|
||||
"loss_scale": scaler,
|
||||
"grad_norm": grad_norm, # TODO: not scalar
|
||||
"grad_norm": grad_norm_dict,
|
||||
}
|
||||
|
||||
infos["micro_num"] = len(batch[1])
|
||||
|
@ -397,24 +403,31 @@ def record_current_batch_training_metrics(
|
|||
line = ""
|
||||
for key, value in infos.items():
|
||||
line += f"{key}={value} "
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
if isinstance(value, dict):
|
||||
writer.add_scalars(key=key, value=value, step=train_state.step_count)
|
||||
else:
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
if update_panel:
|
||||
# metrics shown with dashboard panels
|
||||
panel_metrics = {
|
||||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"loss": loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
}
|
||||
for norm_key, norm_value in grad_norm_dict.items():
|
||||
panel_metrics[norm_key] = norm_value
|
||||
|
||||
logger.info(
|
||||
line,
|
||||
extra={
|
||||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"grad_norm": grad_norm,
|
||||
"loss": loss.item(),
|
||||
"moe_loss": moe_loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
},
|
||||
"{line}",
|
||||
line=line,
|
||||
extra=panel_metrics,
|
||||
)
|
||||
else:
|
||||
logger.info(line)
|
||||
|
|
|
@ -134,6 +134,14 @@ class Writer:
|
|||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def add_scalars(self, key, value, step):
|
||||
try:
|
||||
assert isinstance(value, dict)
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def add_text(self, key, value, step):
|
||||
try:
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
|
|
4
train.py
4
train.py
|
@ -6,7 +6,6 @@ import time
|
|||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
@ -227,7 +226,6 @@ def main(args):
|
|||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
|
@ -262,7 +260,7 @@ def main(args):
|
|||
start_time=start_time,
|
||||
loss=loss,
|
||||
moe_loss=moe_loss,
|
||||
grad_norm=np.linalg.norm(grad_norm_groups),
|
||||
grad_norm=grad_norm_groups,
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue