diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 0d5ce4b..7dd592e 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -124,7 +124,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
-        self._bucket_store = BucketStore(ParallelMode.DATA)
+        self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
+        self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
 
         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
@@ -263,12 +264,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _get_real_dp_process_group(self, param_groups):
-        if "moe" in param_groups.keys() and param_groups["moe"]:
-            return ParallelMode.EXPERT_DATA
-        else:
-            return ParallelMode.DATA
-
     def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
@@ -317,7 +312,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad and not is_moe_param(param):
+                if param.requires_grad:
                     reduce_rank = None
 
                     def _define_and_attach(param, reduce_rank=None):
@@ -347,8 +342,13 @@ class HybridZeroOptimizer(BaseOptimizer):
         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
+        if is_moe_param(param):
+            current_bucket = self._moe_bucket_store
+        else:
+            current_bucket = self._non_moe_bucket_store
+
+        if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
 
         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)
@@ -362,19 +362,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # the param must have grad for reduction
         assert param.grad is not None, f"Parameter of size ({param.size()}) has None grad, cannot be reduced"
 
-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_grad(param.grad, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
+        current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
+        current_bucket.add_grad(param.grad, reduce_rank)
+        current_bucket.add_param(param, reduce_rank)
 
-    def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
+    def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_bucket=False):
         # reduce grads
         self._reduce_grads_by_rank(
             reduce_rank=reduce_rank,
-            grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-            bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
+            grads=current_bucket.get_grad(reduce_rank=reduce_rank),
+            bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
+            dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
         )
 
-        params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
+        params_in_bucket = current_bucket.get_param(reduce_rank=reduce_rank)
 
         for param in params_in_bucket:
             # the is_param_reduced flag should be False showing that
@@ -396,9 +397,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             else:
                 self._param_store.add_previous_reduced_param(param)
 
-        self._bucket_store.reset_by_rank(reduce_rank)
+        current_bucket.reset_by_rank(reduce_rank)
 
-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
+    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
         grad_buckets_by_dtype = split_half_float_double(grads)
 
         for tensor_list in grad_buckets_by_dtype:
@@ -406,12 +407,14 @@ class HybridZeroOptimizer(BaseOptimizer):
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
                 if param_bucket.is_full_or_oversized():
-                    self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                    self._reduce_and_copy(
+                        bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode
+                    )
                     param_bucket.empty()
             if not param_bucket.is_empty():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
 
-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
+    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
         if self._overlap_communication:
             stream = self._comm_stream
             stream.synchronize()
@@ -422,7 +425,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         with torch.cuda.stream(stream):
             flat = bucket.flatten()
             reduced_flat = reduce_tensor(
-                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
+                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=dp_parallel_mode
             )
 
             # update the reduced tensor
@@ -539,11 +542,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(len(self._fp16_param_groups)):
             for param in self._fp16_param_groups[group_id]:
                 # we should not reduce the param in moe
-                if param.grad is not None and not is_moe_param(param):
+                if param.grad is not None:
                     self._store_and_try_reduce_grads_by_bucket(param)
 
         # we need to reduce the gradients left in the communication bucket
-        self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
+        self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
+        self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
 
         # compute norm for gradients in the before bucket
         groups_norms = []
diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py
index 05a44d2..262cb1f 100644
--- a/internlm/solver/optimizer/store.py
+++ b/internlm/solver/optimizer/store.py
@@ -38,12 +38,16 @@ class BucketStore(BaseStore):
         self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
+        self._dp_parallel_mode = dp_parallel_mode
 
         self.reset()
 
     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]
 
+    def get_dp_parallel_mode(self):
+        return self._dp_parallel_mode
+
     def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
         self._num_elements_in_bucket[reduce_rank] += num_elements
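The core idea of the change is that MoE (expert) gradients and ordinary gradients must be all-reduced over different process groups, so each group gets its own bucket store and every bucket remembers which `ParallelMode` it reduces over. Below is a minimal, self-contained sketch of that routing logic, not the actual InternLM classes: `ParallelMode`, `SimpleBucketStore`, and `pick_bucket` are illustrative stand-ins for the real `BucketStore`, `is_moe_param`, and `gpc` parallel modes.

```python
from enum import Enum, auto


class ParallelMode(Enum):
    # stand-ins for the real gpc parallel modes
    DATA = auto()
    EXPERT_DATA = auto()


class SimpleBucketStore:
    """Toy stand-in for BucketStore: tracks element counts and remembers
    which data-parallel group its gradients should be reduced over."""

    def __init__(self, dp_parallel_mode: ParallelMode):
        self._dp_parallel_mode = dp_parallel_mode
        self._num_elements = 0

    def get_dp_parallel_mode(self) -> ParallelMode:
        return self._dp_parallel_mode

    def add_num_elements_in_bucket(self, num_elements: int) -> None:
        self._num_elements += num_elements

    def num_elements_in_bucket(self) -> int:
        return self._num_elements

    def reset(self) -> None:
        self._num_elements = 0


def pick_bucket(param_is_moe: bool,
                non_moe_bucket: SimpleBucketStore,
                moe_bucket: SimpleBucketStore) -> SimpleBucketStore:
    """Route a parameter's gradient to the bucket that matches its process
    group, mirroring the is_moe_param() branch added in the diff above."""
    return moe_bucket if param_is_moe else non_moe_bucket


if __name__ == "__main__":
    non_moe = SimpleBucketStore(ParallelMode.DATA)
    moe = SimpleBucketStore(ParallelMode.EXPERT_DATA)

    # pretend we saw one expert (MoE) parameter and one shared parameter
    pick_bucket(True, non_moe, moe).add_num_elements_in_bucket(1024)
    pick_bucket(False, non_moe, moe).add_num_elements_in_bucket(4096)

    # each bucket is later reduced over the group it remembers, e.g.
    # reduce_tensor(..., parallel_mode=bucket.get_dp_parallel_mode())
    for bucket in (non_moe, moe):
        print(bucket.get_dp_parallel_mode().name, bucket.num_elements_in_bucket())
```

Storing the parallel mode on the bucket (rather than deriving it per param group, as the removed `_get_real_dp_process_group` did) lets `_reduce_and_copy` stay agnostic about MoE: it simply reduces over whatever group the bucket hands it.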