mirror of https://github.com/InternLM/InternLM

set default expert parallel size

parent f5caa1c048
commit e2b7a7fa89

@@ -467,6 +467,14 @@ class ParallelContext(metaclass=SingletonMeta):
         if self.zero1_parallel_size <= 0:
             self.zero1_parallel_size = self.data_parallel_size

+        # if not set expert_parallel_size in parallel config
+        if self.expert_parallel_size <= 0:
+            # by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller
+            # than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device
+            # has one expert.
+            self.expert_parallel_size = min(self.data_parallel_size, self.config.model.get("num_experts", 1))
+            logger.warning(f"not set expert parallel size, set it as {self.expert_parallel_size}")
+
         self.check_sanity()

         initializer_args = [
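
The fallback added above reduces to a one-line rule: take the data-parallel size, capped by the number of experts. A minimal sketch of that rule in isolation (the helper name default_expert_parallel_size is ours, not part of the repository):

def default_expert_parallel_size(data_parallel_size: int, num_experts: int = 1) -> int:
    # Cap the expert parallel size at the number of experts so that every
    # rank in the expert parallel group still owns at least one whole expert.
    return min(data_parallel_size, num_experts)


# 8 data-parallel ranks but only 4 experts -> expert parallel size 4
assert default_expert_parallel_size(8, num_experts=4) == 4
# dense model: num_experts defaults to 1, so expert parallelism collapses to 1
assert default_expert_parallel_size(8) == 1
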
@@ -74,7 +74,7 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)

     if "expert" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("expert", 1)
+        gpc.config.parallel._add_item("expert", -1)

     # processing the data config in gpc
     data = gpc.config.data
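
With the default flipped from 1 to -1, an unset parallel.expert now acts as a sentinel telling ParallelContext to derive the value from the model. A hedged sketch of the two config fields involved (the field names follow the hunks above; the surrounding schema and the concrete values are illustrative only):

# illustrative values; only the two fields touched by this commit matter here
parallel = dict(
    tensor=1,
    expert=-1,  # <= 0 means "derive it": min(data_parallel_size, num_experts)
)

model = dict(
    num_experts=4,  # with 8 data-parallel ranks this yields expert_parallel_size == 4
)
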
@@ -37,7 +37,7 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
     is_no_pp_or_last_stage,
-    sync_model_param_with_ep,
+    sync_model_param,
     sync_model_param_within_tp,
 )
 from internlm.utils.registry import MODEL_INITIALIZER

@@ -80,7 +80,7 @@ def initialize_model():
     # This sync is very important, cause the model weights kept in optimizer are copied
     # from the origin parameters in the memory, so we should make sure the dp sync
     # does not influence the model weights in optimizer be different with the origin parameters.
-    sync_model_param_with_ep(model)
+    sync_model_param(model)

     # This function is needed to make sure parameters that are not splitted by tensor parallelism are
     # the same across tensor parallelism.
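
The comment block above is the reason the sync has to run before the optimizer copies the weights: every data-parallel rank must hold identical parameters first. A small consistency check in the same spirit, not part of the repository, assuming torch.distributed is already initialized and group is the data-parallel process group:

import torch
import torch.distributed as dist


def assert_params_in_sync(model: torch.nn.Module, group=None) -> None:
    # Heuristic check: reduce each parameter's checksum with MIN and MAX;
    # if any rank holds different weights the two reductions will disagree.
    for name, param in model.named_parameters():
        checksum = param.detach().double().sum()
        lo, hi = checksum.clone(), checksum.clone()
        dist.all_reduce(lo, op=dist.ReduceOp.MIN, group=group)
        dist.all_reduce(hi, op=dist.ReduceOp.MAX, group=group)
        assert torch.equal(lo, hi), f"parameter {name} differs across ranks"
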
@@ -12,45 +12,24 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)


-def sync_model_param(model, parallel_mode):
-    r"""Make sure data parameters are consistent during Data Parallel Mode.
-
-    Args:
-        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
-        parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
-    """
-    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
-        for param in model.parameters():
-            if is_moe_param(param):
-                # TODO: moe expert param need to sync in expert data parallel group
-                # now we do not support expert data parallel
-                pass
-            else:
-                ranks = gpc.get_ranks_in_group(parallel_mode)
-                dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
-
-
 def sync_tensor(tensor, parallel_mode):
     r"""Make sure data tensor(parameters) are consistent during Data and Expert Parallel Mode.

     Args:
         tensor (:class:`torch.Tensor`): A parameters you check the consistency.
         parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
     """
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         ranks = gpc.get_ranks_in_group(parallel_mode)
         dist.broadcast(tensor, src=ranks[0], group=gpc.get_group(parallel_mode))


-# TODO: will be used in expert data parallel, may can also used in sync_model_param_within_tp
-def sync_model_param_with_ep(model):
+def sync_model_param(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.

     Args:
         model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
     """
     if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
+        sync_moe_param = (
+            gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1
+        )
         for param in model.parameters():
-            if is_moe_param(param):
+            if sync_moe_param and is_moe_param(param):
                 sync_tensor(param, ParallelMode.EXPERT_DATA)
             else:
                 sync_tensor(param, ParallelMode.DATA)
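
The net effect of this hunk: the old two-function split (a generic sync_model_param plus sync_model_param_with_ep) collapses into one sync_model_param that broadcasts dense parameters over the data-parallel group and expert parameters over the expert-data-parallel group, skipping the expert branch when that group was never initialized. A stripped-down sketch of the same routing, with plain torch.distributed handles standing in for gpc (the is_expert attribute and the explicit rank lists are assumptions for illustration; in the repo they come from is_moe_param and gpc.get_ranks_in_group):

import torch.distributed as dist


def sync_params(model, data_group, data_ranks, expert_data_group=None, expert_data_ranks=None):
    # Broadcast every parameter from the first global rank of whichever
    # process group it is replicated over, mirroring sync_tensor above.
    for param in model.parameters():
        is_expert = getattr(param, "is_expert", False)  # stand-in for is_moe_param(param)
        if expert_data_group is not None and is_expert:
            dist.broadcast(param, src=expert_data_ranks[0], group=expert_data_group)
        else:
            dist.broadcast(param, src=data_ranks[0], group=data_group)
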