modified:   internlm/core/context/process_group_initializer.py
modified:   internlm/core/scheduler/no_pipeline_scheduler.py
modified:   internlm/solver/optimizer/hybrid_zero_optim.py
pull/375/head
Wenwen Qu 2023-08-08 15:59:12 +08:00
parent 2a52452ed2
commit 84476833f3
3 changed files with 6 additions and 5 deletions

internlm/core/context/process_group_initializer.py

@@ -343,7 +343,8 @@ class Initializer_Zero1(ProcessGroupInitializer):
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_Expert(ProcessGroupInitializer):
"""A ProcessGroupInitializer for zero-1 parallelism.

internlm/core/scheduler/no_pipeline_scheduler.py

@@ -116,7 +116,7 @@ class NonPipelineScheduler(BaseScheduler):
self._call_hooks("after_criterion", loss)
moe_loss = sum(moe_losses) * moe_loss_coeff
loss += moe_loss
- loss /= scale_loss ## TODO: check whether mos_loss should be scaled
+ loss /= scale_loss # TODO: check whether mos_loss should be scaled
# backward
if not forward_only:
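As a quick reference for this hunk, the sketch below shows the pattern being edited: the weighted MoE auxiliary loss is added to the criterion loss, and the sum is divided by the gradient-accumulation factor before `backward()`. The function name and the toy tensors are illustrative assumptions; only `moe_losses`, `moe_loss_coeff`, and `scale_loss` correspond to names in the diff.

```python
# Minimal sketch (not the repository's scheduler code) of combining the MoE
# auxiliary loss with the main loss and rescaling for gradient accumulation.
import torch

def combine_losses(loss, moe_losses, moe_loss_coeff, scale_loss):
    """Add the weighted MoE auxiliary loss, then divide by the accumulation factor."""
    moe_loss = sum(moe_losses) * moe_loss_coeff
    loss = loss + moe_loss
    # scale_loss is typically the number of gradient-accumulation micro-steps,
    # so each backward() contributes 1/scale_loss of the full-batch gradient.
    return loss / scale_loss

if __name__ == "__main__":
    logits = torch.randn(4, requires_grad=True)
    main_loss = (logits ** 2).mean()              # stand-in for the criterion output
    aux = [torch.tensor(0.1), torch.tensor(0.3)]  # stand-in per-layer MoE balance losses
    total = combine_losses(main_loss, aux, moe_loss_coeff=0.01, scale_loss=4)
    total.backward()
    print(total.item(), logits.grad is not None)
```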

internlm/solver/optimizer/hybrid_zero_optim.py

@@ -9,6 +9,7 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
+ from internlm.model.moe import is_moe_param
from internlm.solver.optimizer.store import (
BucketStore,
GradientStore,
@@ -25,7 +26,6 @@ from internlm.solver.optimizer.utils import (
split_half_float_double,
sync_param,
)
- from internlm.model.moe import is_moe_param
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
@@ -537,7 +537,7 @@ class HybridZeroOptimizer(BaseOptimizer):
norm_groups.append(norm_group)
loss_scale = float(self.loss_scale.item()) # backup
- if not gpc.config.model.dtype is torch.float32:
+ if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
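For context on `self.grad_scaler.update(found_inf)`: a dynamic loss scaler conventionally backs the scale off after an overflow and grows it again after a long run of clean steps. The class below is a hedged, self-contained sketch of that convention, not InternLM's grad scaler; the class name and the default backoff/growth factors are assumptions.

```python
# Hedged sketch of dynamic loss scaling: shrink the scale when inf/nan grads
# are found, grow it after `growth_interval` consecutive clean steps.
class SimpleLossScaler:
    def __init__(self, init_scale=2.0**16, backoff=0.5, growth=2.0, growth_interval=1000):
        self.scale = init_scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow: reduce the scale and restart the clean-step counter.
            self.scale *= self.backoff
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                # Long enough without overflow: safe to raise the scale again.
                self.scale *= self.growth

scaler = SimpleLossScaler()
scaler.update(found_inf=True)  # e.g. after detecting inf/nan gradients
print(scaler.scale)            # 32768.0
```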
@@ -581,7 +581,7 @@ class HybridZeroOptimizer(BaseOptimizer):
global_norm = sum(norm_groups) ** 0.5
# the following operations are performed only on the rank to which parameters are assigned.
- if not gpc.config.model.dtype is torch.float32:
+ if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
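The call above passes the global norm and the backed-up loss scale into `_unscale_and_clip_grads`, whose body is not part of this diff. The sketch below illustrates the usual unscale-then-clip arithmetic under stated assumptions (gradients were multiplied by `loss_scale`; `global_norm` was computed on unscaled gradients); the function name and the `clip_grad_norm` default are illustrative, not taken from the optimizer.

```python
# Hedged sketch, not HybridZeroOptimizer._unscale_and_clip_grads: divide grads
# by the loss scale, then clip so the global gradient norm stays <= clip_grad_norm.
import torch

def unscale_and_clip_(grads, global_norm: float, loss_scale: float, clip_grad_norm: float = 1.0):
    combined = 1.0 / loss_scale          # undo mixed-precision loss scaling
    if clip_grad_norm > 0.0:
        clip = global_norm / clip_grad_norm
        if clip > 1.0:
            combined /= clip             # fold the clipping ratio into one multiplier
    for g in grads:
        g.mul_(combined)

grads = [torch.ones(4) * 8.0]
unscale_and_clip_(grads, global_norm=4.0, loss_scale=2.0, clip_grad_norm=1.0)
print(grads[0])  # tensor([1., 1., 1., 1.])
```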