diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index cba39fb..7f3e415 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -463,12 +463,19 @@ class ParallelContext(metaclass=SingletonMeta):
         # the recommended nettest_parallel_size is 32 GPUs
         self.nettest_parallel_size = 32
 
-        # TODO : data parallel size can be different with expert parallel size
-        self.expert_parallel_size = self.data_parallel_size
-
         if self.zero1_parallel_size <= 0:
             self.zero1_parallel_size = self.data_parallel_size
 
+        assert (
+            self.data_parallel_size % self.config.model.get("num_experts", 1) == 0
+            or self.config.model.get("num_experts", 1) % self.data_parallel_size == 0
+        ), "can not place the experts evenly"
+
+        # by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller
+        # than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device
+        # has one expert.
+        self.expert_parallel_size = min(self.data_parallel_size, self.config.model.get("num_experts", 1))
+
         self.check_sanity()
 
         initializer_args = [
@@ -492,7 +499,7 @@ class ParallelContext(metaclass=SingletonMeta):
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
         if self.config.model.get("num_experts", 1) > 1:
-            initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args))
+            initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args))
         for initializer in initializers:
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index ccdb53c..7abca14 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -125,7 +125,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
-        self._bucket_store = BucketStore(ParallelMode.DATA)
+        self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
+        self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
         self._bucket_in_progress = []
 
         # fp16 and fp32 params for mixed precision training
@@ -321,7 +322,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad and not is_moe_param(param):
+                if param.requires_grad:
                     reduce_rank = None
 
                     def _define_and_attach(param, reduce_rank=None):
@@ -353,8 +354,13 @@ class HybridZeroOptimizer(BaseOptimizer):
         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
+        if is_moe_param(param):
+            current_bucket = self._moe_bucket_store
+        else:
+            current_bucket = self._non_moe_bucket_store
+
+        if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
 
         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)
@@ -368,19 +374,20 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # the param must have grad for reduction
         assert param.grad is not None, f"Parameter of size ({param.size()}) has None grad, cannot be reduced"
 
-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_grad(param.grad, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
+        current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
+        current_bucket.add_grad(param.grad, reduce_rank)
+        current_bucket.add_param(param, reduce_rank)
 
-    def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
+    def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_bucket=False):
         # reduce grads
         self._reduce_grads_by_rank(
             reduce_rank=reduce_rank,
-            grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-            bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
+            grads=current_bucket.get_grad(reduce_rank=reduce_rank),
+            bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
+            dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
         )
-        params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
+        params_in_bucket = current_bucket.get_param(reduce_rank=reduce_rank)
 
         for param in params_in_bucket:
             # the is_param_reduced flag should be False showing that
@@ -402,9 +409,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             else:
                 self._param_store.add_previous_reduced_param(param)
 
-        self._bucket_store.reset_by_rank(reduce_rank)
+        current_bucket.reset_by_rank(reduce_rank)
 
-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
+    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
         grad_buckets_by_dtype = split_half_float_double(grads)
         next_bucket_list = []
         # add parameters into bucket for reduction
@@ -413,7 +420,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
             if not param_bucket.is_empty():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
             next_bucket_list.append(param_bucket)
 
         # wait for the completion of previouce bucket list reduction, and do unflatten_and_copy()
@@ -428,14 +435,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         # after the completion of bucket list reduction, add new buckets into _bucket_in_progress
         self._bucket_in_progress = next_bucket_list.copy()
 
-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
+    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
         # flatten the tensors and do allreduce
         bucket.flatten()
         bucket.commu_handle = reduce_tensor(
             tensor=bucket.get_flat_tensor(),
             dtype=None,
             dst_rank=reduce_rank,
-            parallel_mode=ParallelMode.DATA,
+            parallel_mode=dp_parallel_mode,
         )
 
         # update the reduced tensor
@@ -581,11 +588,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(len(self._fp16_param_groups)):
            for param in self._fp16_param_groups[group_id]:
                 # we should not reduce the param in moe
-                if param.grad is not None and not is_moe_param(param):
+                if param.grad is not None:
                     self._store_and_try_reduce_grads_by_bucket(param)
 
         # we need to reduce the gradients left in the communication bucket
-        self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
+        self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
+        self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
 
         # compute norm for gradients in the before bucket
         groups_norms = []
diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py
index adab6c9..7c46b83 100644
--- a/internlm/solver/optimizer/store.py
+++ b/internlm/solver/optimizer/store.py
@@ -38,12 +38,16 @@ class BucketStore(BaseStore):
         self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
+        self._dp_parallel_mode = dp_parallel_mode
 
         self.reset()
 
     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]
 
+    def get_dp_parallel_mode(self):
+        return self._dp_parallel_mode
+
     def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
         self._num_elements_in_bucket[reduce_rank] += num_elements
 
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index e423ea6..b7a369a 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -80,7 +80,7 @@ def initialize_model():
     # This sync is very important, cause the model weights kept in optimizer are copied
     # from the origin parameters in the memory, so we should make sure the dp sync
     # does not influence the model weights in optimizer be different with the origin parameters.
-    sync_model_param(model, parallel_mode=ParallelMode.DATA)
+    sync_model_param(model)
 
     # This function is needed to make sure parameters that are not splitted by tensor parallelism are
     # the same across tensor parallelism.
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 3dd57eb..566bb0f 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -258,7 +258,14 @@ def save_model_checkpoint(folder, model):
             llm_save(topo_fp, saved_obj=topo)
 
     # try to save expert parameter to separate files if model have moe layer
-    try_save_moe_checkpoint(folder, model, tp_rank, pp_rank)
+    expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA)
+    expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
+    should_save_rank_pair.clear()
+    for i in range(tp_size):
+        should_save_rank_pair.add((i, i % expert_dp_size))
+
+    if (tp_rank, expert_dp_rank) in should_save_rank_pair:
+        try_save_moe_checkpoint(folder, model, tp_rank, pp_rank)
 
     torch.distributed.barrier()
 
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index b7e3b86..3a10227 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -12,48 +12,23 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
 
-def sync_model_param(model, parallel_mode):
+def sync_model_param(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
 
     Args:
         model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
-        parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
     """
-    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
+    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
+        sync_moe_param = (
+            gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1
+        )
         for param in model.parameters():
-            if is_moe_param(param):
-                # TODO: moe expert param need to sync in expert data parallel group
-                # now we do not support expert data parallel
-                pass
+            if sync_moe_param and is_moe_param(param):
+                ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
+                dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
             else:
-                ranks = gpc.get_ranks_in_group(parallel_mode)
-                dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
-
-
-def sync_tensor(tensor, parallel_mode):
-    r"""Make sure data tensor(parameters) are consistent during Data and Expert Parallel Mode.
-
-    Args:
-        tensor (:class:`torch.Tensor`): A parameters you check the consistency.
-        parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
-    """
-    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
-        ranks = gpc.get_ranks_in_group(parallel_mode)
-        dist.broadcast(tensor, src=ranks[0], group=gpc.get_group(parallel_mode))
-
-
-# TODO: will be used in expert data parallel, may can also used in sync_model_param_within_tp
-def sync_model_param_with_ep(model):
-    r"""Make sure data parameters are consistent during Data Parallel Mode.
-
-    Args:
-        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
-    """
-    for param in model.parameters():
-        if is_moe_param(param):
-            sync_tensor(param, ParallelMode.EXPERT_DATA)
-        else:
-            sync_tensor(param, ParallelMode.DATA)
+                ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
+                dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
 
 
 def sync_model_param_within_tp(model):