Add code for reducing gradients of MoE parameters in the expert data-parallel group

pull/375/head
Wenwen Qu 2023-08-22 17:29:28 +08:00
parent 12c614db94
commit 14a81e5c1d
2 changed files with 33 additions and 25 deletions
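
For context: in this hybrid ZeRO setup, MoE expert parameters are replicated only across the ranks that hold the same expert (the expert data-parallel group, ParallelMode.EXPERT_DATA), while all other parameters are replicated across the full data-parallel group (ParallelMode.DATA). Their gradients therefore have to be bucketed and reduced separately, one bucket per group, which is what this commit introduces. Below is a toy illustration of how the two kinds of groups can partition the ranks; this particular layout is an assumption for illustration only, not the real grouping logic.

# Toy rank layout, NOT the real grouping: 8 ranks, 2-way expert parallelism (assumed).
world_size = 8
expert_parallel_size = 2  # assumed number of expert-parallel ranks

# Ordinary (non-MoE) parameters: every rank holds a replica, so gradients
# are averaged over the full data-parallel group.
data_group = list(range(world_size))

# Expert parameters: only ranks holding the same expert shard are replicas,
# so gradients are averaged over the smaller expert data-parallel group.
expert_data_groups = [
    [r for r in range(world_size) if r % expert_parallel_size == k]
    for k in range(expert_parallel_size)
]

print(data_group)          # [0, 1, 2, 3, 4, 5, 6, 7]
print(expert_data_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]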


@@ -124,7 +124,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(ParallelMode.ZERO1)
self._grad_store = GradientStore(ParallelMode.DATA)
- self._bucket_store = BucketStore(ParallelMode.DATA)
+ self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
+ self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
@@ -263,12 +264,6 @@ class HybridZeroOptimizer(BaseOptimizer):
def num_param_groups(self):
return len(self._fp16_param_groups)
- def _get_real_dp_process_group(self, param_groups):
-     if "moe" in param_groups.keys() and param_groups["moe"]:
-         return ParallelMode.EXPERT_DATA
-     else:
-         return ParallelMode.DATA
def _partition_param_list(self, param_group):
no_params_ranks = []
params_per_rank = [[] for _ in range(self._zero_world_size)]
@@ -317,7 +312,7 @@ class HybridZeroOptimizer(BaseOptimizer):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
# we should not reduce the param in moe
- if param.requires_grad and not is_moe_param(param):
+ if param.requires_grad:
reduce_rank = None
def _define_and_attach(param, reduce_rank=None):
@@ -347,8 +342,13 @@ class HybridZeroOptimizer(BaseOptimizer):
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
- if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-     self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
+ if is_moe_param(param):
+     current_bucket = self._moe_bucket_store
+ else:
+     current_bucket = self._non_moe_bucket_store
+ if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+     self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
@@ -362,19 +362,20 @@ class HybridZeroOptimizer(BaseOptimizer):
# the param must have grad for reduction
assert param.grad is not None, f"Parameter of size ({param.size()}) has None grad, cannot be reduced"
- self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
- self._bucket_store.add_grad(param.grad, reduce_rank)
- self._bucket_store.add_param(param, reduce_rank)
+ current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
+ current_bucket.add_grad(param.grad, reduce_rank)
+ current_bucket.add_param(param, reduce_rank)
- def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
+ def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_bucket=False):
# reduce grads
self._reduce_grads_by_rank(
reduce_rank=reduce_rank,
- grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
- bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
+ grads=current_bucket.get_grad(reduce_rank=reduce_rank),
+ bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
+ dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
)
- params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
+ params_in_bucket = current_bucket.get_param(reduce_rank=reduce_rank)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
@@ -396,9 +397,9 @@ class HybridZeroOptimizer(BaseOptimizer):
else:
self._param_store.add_previous_reduced_param(param)
- self._bucket_store.reset_by_rank(reduce_rank)
+ current_bucket.reset_by_rank(reduce_rank)
- def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
+ def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
grad_buckets_by_dtype = split_half_float_double(grads)
for tensor_list in grad_buckets_by_dtype:
@@ -406,12 +407,14 @@ class HybridZeroOptimizer(BaseOptimizer):
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
- self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+ self._reduce_and_copy(
+     bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode
+ )
param_bucket.empty()
if not param_bucket.is_empty():
- self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+ self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
- def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
+ def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
if self._overlap_communication:
stream = self._comm_stream
stream.synchronize()
@@ -422,7 +425,7 @@ class HybridZeroOptimizer(BaseOptimizer):
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
- tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
+ tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=dp_parallel_mode
)
# update the reduced tensor
@@ -539,11 +542,12 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(len(self._fp16_param_groups)):
for param in self._fp16_param_groups[group_id]:
# we should not reduce the param in moe
- if param.grad is not None and not is_moe_param(param):
+ if param.grad is not None:
self._store_and_try_reduce_grads_by_bucket(param)
# we need to reduce the gradients left in the communication bucket
- self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
+ self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
+ self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
# compute norm for gradients in the before bucket
groups_norms = []
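
Taken together, the optimizer-side changes above route each gradient into the bucket that matches its replication group and then reduce every bucket over that group: the reduce_tensor call now receives the bucket's dp_parallel_mode instead of the hard-coded ParallelMode.DATA. The snippet below is a rough standalone analogue of that per-group reduction using torch.distributed, not the real helper; the real code flattens the whole bucket into one tensor and resolves the process group from the parallel mode, so the group argument here is a stand-in and an initialized process group is assumed.

import torch.distributed as dist

def flush_bucket(grads, dst_rank, group):
    # Reduce one bucket's gradients over its own data-parallel group:
    # the non-MoE bucket would pass the full data-parallel group here,
    # the MoE bucket the expert data-parallel group.
    for g in grads:
        g.div_(dist.get_world_size(group=group))      # average; the real helper may scale differently
        if dst_rank is None:
            dist.all_reduce(g, group=group)            # every rank keeps the result
        else:
            dist.reduce(g, dst=dst_rank, group=group)  # only dst_rank keeps it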


@@ -38,12 +38,16 @@ class BucketStore(BaseStore):
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
+ self._dp_parallel_mode = dp_parallel_mode
self.reset()
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
+ def get_dp_parallel_mode(self):
+     return self._dp_parallel_mode
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
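
Because each BucketStore is now constructed with the parallel mode of the group it belongs to, the optimizer no longer needs the removed _get_real_dp_process_group helper: it simply asks the bucket which group to reduce over. Below is a minimal sketch of that hand-off; TinyBucketStore and reduce_fn are illustrative stand-ins, not the real classes.

class TinyBucketStore:
    # Remembers the data-parallel mode it was created with, like BucketStore.
    def __init__(self, dp_parallel_mode):
        self._dp_parallel_mode = dp_parallel_mode
        self.grads = []

    def get_dp_parallel_mode(self):
        return self._dp_parallel_mode

def reduce_bucket(bucket, reduce_fn):
    # The reduction path no longer hard-codes ParallelMode.DATA; it reduces
    # over whatever group the bucket itself reports.
    reduce_fn(bucket.grads, parallel_mode=bucket.get_dp_parallel_mode())

non_moe_bucket = TinyBucketStore("data")
moe_bucket = TinyBucketStore("expert_data")

# At the end of a step both buckets are flushed, mirroring the two
# _reduce_grads_stored_in_bucket calls in the first file.
for bucket in (non_moe_bucket, moe_bucket):
    reduce_bucket(bucket, reduce_fn=lambda grads, parallel_mode: None)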