diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 9cc7a7d..09be064 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -343,7 +343,8 @@ class Initializer_Zero1(ProcessGroupInitializer):
                 ranks_in_group = ranks
 
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
-
+
+
 class Initializer_Expert(ProcessGroupInitializer):
     """A ProcessGroupInitializer for zero-1 parallelism.
 
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index e2084d5..5a6749e 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -116,7 +116,7 @@ class NonPipelineScheduler(BaseScheduler):
                 self._call_hooks("after_criterion", loss)
                 moe_loss = sum(moe_losses) * moe_loss_coeff
                 loss += moe_loss
-                loss /= scale_loss  ## TODO: check whether mos_loss should be scaled
+                loss /= scale_loss  # TODO: check whether mos_loss should be scaled
 
         # backward
         if not forward_only:
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 41ab97c..1cdb8f7 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -9,6 +9,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.moe import is_moe_param
 from internlm.solver.optimizer.store import (
     BucketStore,
     GradientStore,
@@ -25,7 +26,6 @@ from internlm.solver.optimizer.utils import (
     split_half_float_double,
     sync_param,
 )
-from internlm.model.moe import is_moe_param
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -537,7 +537,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             norm_groups.append(norm_group)
 
         loss_scale = float(self.loss_scale.item())  # backup
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             self.grad_scaler.update(found_inf)
             # update loss scale if overflow occurs
             if found_inf:
@@ -581,7 +581,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         global_norm = sum(norm_groups) ** 0.5
 
         # the following operations are performed only on the rank to which parameters are assigned.
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
                 self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
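
Note on the NonPipelineScheduler hunk above: a minimal, self-contained sketch (not InternLM's implementation; the helper `combine_and_scale_loss` and all values below are hypothetical) of how the weighted MoE auxiliary loss is folded into the task loss and then divided by the gradient-accumulation factor, which is the step the TODO in the diff questions for the auxiliary term.

import torch

def combine_and_scale_loss(loss, moe_losses, moe_loss_coeff, scale_loss):
    # Weight and sum the per-layer MoE auxiliary (load-balancing) losses.
    moe_loss = sum(moe_losses) * moe_loss_coeff
    # Fold the auxiliary loss into the task loss.
    total = loss + moe_loss
    # Divide by the gradient-accumulation factor so gradients average over
    # micro-batches; the diff's TODO asks whether moe_loss should also be
    # divided this way or kept at full magnitude.
    return total / scale_loss

# Usage with dummy values.
task_loss = torch.tensor(2.0, requires_grad=True)
aux_losses = [torch.tensor(0.10), torch.tensor(0.05)]
scaled = combine_and_scale_loss(task_loss, aux_losses, moe_loss_coeff=0.01, scale_loss=4)
scaled.backward()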