diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 5a6749e..2e931ea 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -116,7 +116,7 @@ class NonPipelineScheduler(BaseScheduler):
                 self._call_hooks("after_criterion", loss)
                 moe_loss = sum(moe_losses) * moe_loss_coeff
                 loss += moe_loss
-                loss /= scale_loss  # TODO: check whether mos_loss should be scaled
+                loss /= scale_loss
 
         # backward
         if not forward_only:
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index db77315..1bf3499 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.moe import is_moe_param
+from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
     GradientStore,
@@ -29,7 +30,6 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.monitor import send_alert_message
 
 from .utils import compute_norm
 
@@ -89,6 +89,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        has_moe: bool = False,
     ):
         # DynamicGradScaler related args
         if gpc.config.model.dtype is torch.float32:
@@ -109,6 +110,8 @@
 
         super().__init__(optim=optimizer)
 
+        self.has_moe = has_moe
+
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -277,7 +280,9 @@
                 no_params_ranks.append(rank)
 
         if gpc.is_rank_for_log():
-            logger.info(f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}")
+            logger.info(  # pylint: disable=W1203
+                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+            )
 
         return params_per_rank, set(no_params_ranks)
 
@@ -503,6 +508,18 @@
 
         return self._step(closure=closure)
 
+    def _get_norm_with_moe_layers(self, norm_groups):
+        # MoE parameters are not synchronized by the regular gradient all-reduce, so the norm has
+        # to be averaged across the data-parallel ranks explicitly.
+        pg = gpc.get_group(ParallelMode.DATA)
+        scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(
+            scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
+        )
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        all_groups_norm = scaled_norm_tensor.item()
+        return all_groups_norm
+
     def _step(self, closure=None):
         assert closure is None, "closure is not supported by step()"
 
@@ -582,6 +599,9 @@
         if self._clip_grad_norm > 0:
             global_norm = sum(norm_groups) ** 0.5
 
+        if self.has_moe:
+            global_norm = self._get_norm_with_moe_layers(global_norm)
+
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
diff --git a/train.py b/train.py
index 72f2820..5de592b 100644
--- a/train.py
+++ b/train.py
@@ -30,6 +30,7 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
+from internlm.model.moe import has_moe_layers
 from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -247,7 +248,7 @@ def get_validation_data_loader(num_worker: int = 0):
         batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
 
         if batch_size == 0 and gpc.is_rank_for_log():
-            logger.info(f"skip validate {val_name}.")
+            logger.info(f"skip validate {val_name}.")  # pylint: disable=W1203
             continue
 
         val_dls[val_name] = get_dpsampler_dataloader(
@@ -255,7 +256,7 @@ def get_validation_data_loader(num_worker: int = 0):
         )  # drop_last=True, otherwise it may cause problems in the last batch
 
         if gpc.is_rank_for_log():
-            logger.info(
+            logger.info(  # pylint: disable=W1203
                 f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
                 f"samples {str(len(val_dls[val_name]))}."
             )
@@ -307,8 +308,12 @@ def initialize_optimizer(model: nn.Module):
         eps=adam_cfg.adam_eps,
     )
 
+    has_moe = has_moe_layers(model)
     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        has_moe=has_moe,
     )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
@@ -472,19 +477,19 @@ def main(args):
     model_load_path = None
     if load_resume_ckpt_folder is not None:
-        logger.info(
+        logger.info(  # pylint: disable=W1203
            f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
            f"{socket.gethostname()}==========="
        )
        model_load_path = load_resume_ckpt_folder
    elif load_model_only_folder is not None:
-        logger.info(
+        logger.info(  # pylint: disable=W1203
            f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
            f"{socket.gethostname()}==========="
        )
        model_load_path = load_model_only_folder
    else:
-        logger.info(
+        logger.info(  # pylint: disable=W1203
            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
            f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
        )
@@ -580,7 +585,7 @@
            train_state.num_consumed_samples_in_epoch += len(batch[1])
        if batch_skipper(batch_count):  # skip this batch
            if gpc.is_rank_for_log():
-                logger.info(f"Skip batch count:`{batch_count}`...")
+                logger.info(f"Skip batch count:`{batch_count}`...")  # pylint: disable=W1203
            timer("one-batch").stop()
            continue
 
@@ -596,7 +601,13 @@
 
        # do forward and backward
        timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff = gpc.config.loss.moe_loss_coeff)
+        _, _, loss = trainer.execute_schedule(
+            batch,
+            forward_only=False,
+            return_loss=True,
+            return_output_label=False,
+            moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
+        )
        timer("fwd-bwd").stop()
 
        # update parameters, and returns (success_update, grad_norm)
@@ -609,7 +620,7 @@
        else:
            train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
            if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
-                logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                logger.warning(f"Warning: skip parameter update at step {batch_count}.")  # pylint: disable=W1203
                send_alert_message(
                    address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
                )
@@ -667,7 +678,7 @@ if __name__ == "__main__":
    try:
        main(args)
    except Exception:
-        logger.error(
+        logger.error(  # pylint: disable=W1203
            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
            exc_info=traceback.format_exc(),
        )
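
Note (not part of the patch): expert gradients are skipped by the regular gradient all-reduce, so each data-parallel rank computes only a partial gradient norm; `_get_norm_with_moe_layers` averages the per-rank norms over the data-parallel group so that every rank clips with the same coefficient. A minimal standalone sketch of that averaging step, assuming an already-initialized torch.distributed process group (the patch itself obtains the group via gpc.get_group(ParallelMode.DATA); `dp_group` and the function name below are illustrative only):

# Illustrative sketch, not code from the patch.
import torch
import torch.distributed as dist


def average_norm_across_dp_ranks(local_norm: float, dp_group=None) -> float:
    # Each rank contributes local_norm / world_size; summing the contributions
    # with all_reduce yields the mean of the per-rank norms.
    world_size = dist.get_world_size(group=dp_group)
    scaled = torch.tensor(float(local_norm) / world_size, dtype=torch.float)
    # With the NCCL backend the tensor must live on the current CUDA device;
    # the gloo backend accepts CPU tensors.
    dist.all_reduce(scaled, group=dp_group)
    return scaled.item()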