mirror of https://github.com/InternLM/InternLM
add codes for reducing moe parameters in expert data group
parent 12c614db94
commit 14a81e5c1d
@@ -124,7 +124,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
-        self._bucket_store = BucketStore(ParallelMode.DATA)
+        self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
+        self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)

         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
@@ -263,12 +264,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)

-    def _get_real_dp_process_group(self, param_groups):
-        if "moe" in param_groups.keys() and param_groups["moe"]:
-            return ParallelMode.EXPERT_DATA
-        else:
-            return ParallelMode.DATA
-
     def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
@@ -317,7 +312,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 # we should not reduce the param in moe
-                if param.requires_grad and not is_moe_param(param):
+                if param.requires_grad:
                     reduce_rank = None

                     def _define_and_attach(param, reduce_rank=None):
@@ -347,8 +342,13 @@ class HybridZeroOptimizer(BaseOptimizer):
         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
+        if is_moe_param(param):
+            current_bucket = self._moe_bucket_store
+        else:
+            current_bucket = self._non_moe_bucket_store
+
+        if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)

         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)
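Both the hook condition and the bucket selection above rely on is_moe_param to tell expert parameters apart from dense ones. That helper is defined elsewhere in the repository and is not part of this diff; a minimal sketch of such a predicate, under the assumption that expert parameters carry an is_expert tag when the MoE layers are built (the real helper may use a different marker), could look like:

import torch
from torch import nn


def is_moe_param(param: nn.Parameter) -> bool:
    # Hypothetical predicate: a parameter counts as an expert (MoE) parameter
    # if the model tagged it with `is_expert` when the experts were created.
    return getattr(param, "is_expert", False)


# Usage sketch: tag an expert weight and check both cases.
expert_weight = nn.Parameter(torch.empty(4, 4))
expert_weight.is_expert = True
assert is_moe_param(expert_weight)
assert not is_moe_param(nn.Parameter(torch.empty(4, 4)))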
@@ -362,19 +362,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # the param must have grad for reduction
         assert param.grad is not None, f"Parameter of size ({param.size()}) has None grad, cannot be reduced"

-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_grad(param.grad, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
+        current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
+        current_bucket.add_grad(param.grad, reduce_rank)
+        current_bucket.add_param(param, reduce_rank)

-    def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
+    def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_bucket=False):
         # reduce grads
         self._reduce_grads_by_rank(
             reduce_rank=reduce_rank,
-            grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-            bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
+            grads=current_bucket.get_grad(reduce_rank=reduce_rank),
+            bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
+            dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
         )

-        params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
+        params_in_bucket = current_bucket.get_param(reduce_rank=reduce_rank)

         for param in params_in_bucket:
             # the is_param_reduced flag should be False showing that
@@ -396,9 +397,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             else:
                 self._param_store.add_previous_reduced_param(param)

-        self._bucket_store.reset_by_rank(reduce_rank)
+        current_bucket.reset_by_rank(reduce_rank)

-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
+    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
         grad_buckets_by_dtype = split_half_float_double(grads)

         for tensor_list in grad_buckets_by_dtype:
@@ -406,12 +407,14 @@ class HybridZeroOptimizer(BaseOptimizer):
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
                 if param_bucket.is_full_or_oversized():
-                    self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                    self._reduce_and_copy(
+                        bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode
+                    )
                     param_bucket.empty()
             if not param_bucket.is_empty():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)

-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
+    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
         if self._overlap_communication:
             stream = self._comm_stream
             stream.synchronize()
@@ -422,7 +425,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         with torch.cuda.stream(stream):
             flat = bucket.flatten()
             reduced_flat = reduce_tensor(
-                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
+                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=dp_parallel_mode
             )

             # update the reduced tensor
@@ -539,11 +542,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(len(self._fp16_param_groups)):
             for param in self._fp16_param_groups[group_id]:
                 # we should not reduce the param in moe
-                if param.grad is not None and not is_moe_param(param):
+                if param.grad is not None:
                     self._store_and_try_reduce_grads_by_bucket(param)

         # we need to reduce the gradients left in the communication bucket
-        self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
+        self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
+        self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)

         # compute norm for gradients in the before bucket
         groups_norms = []
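Taken together, the optimizer now keeps one gradient bucket per data-parallel flavour and flushes each over its own process group: gradients of dense parameters are reduced across the full data-parallel group (ParallelMode.DATA), while gradients of expert parameters are reduced only across the expert-data group (ParallelMode.EXPERT_DATA). The following standalone sketch shows the routing idea with the collective call replaced by a stub so it runs without an initialized torch.distributed process group; ParallelMode, route_grad, and flush here are illustrative names, not the repository's API:

from enum import Enum, auto

import torch
from torch import nn


class ParallelMode(Enum):
    DATA = auto()         # full data-parallel group (dense parameters)
    EXPERT_DATA = auto()  # smaller group used for expert parameters


# One bucket of pending gradients per parallel mode, mirroring the
# _non_moe_bucket_store / _moe_bucket_store pair introduced above.
buckets = {mode: [] for mode in ParallelMode}


def route_grad(param: nn.Parameter) -> None:
    # Put the parameter's gradient into the bucket of its process group.
    is_expert = getattr(param, "is_expert", False)
    mode = ParallelMode.EXPERT_DATA if is_expert else ParallelMode.DATA
    buckets[mode].append(param.grad)


def flush(mode: ParallelMode) -> None:
    # Stand-in for the real reduction: the optimizer flattens the bucket and
    # calls reduce_tensor(..., parallel_mode=mode) over the matching group.
    print(f"reducing {len(buckets[mode])} grad(s) over {mode.name}")
    buckets[mode].clear()


# Usage: two dense parameters and one expert parameter with dummy gradients.
dense_a, dense_b = nn.Parameter(torch.ones(2)), nn.Parameter(torch.ones(2))
expert = nn.Parameter(torch.ones(2))
expert.is_expert = True
for p in (dense_a, dense_b, expert):
    p.grad = torch.zeros_like(p)
    route_grad(p)

flush(ParallelMode.DATA)         # reducing 2 grad(s) over DATA
flush(ParallelMode.EXPERT_DATA)  # reducing 1 grad(s) over EXPERT_DATA

The BucketStore change below is what lets _reduce_grads_stored_in_bucket recover the right group at flush time: each store is constructed with its parallel mode and hands it back through get_dp_parallel_mode(), which then flows into reduce_tensor as shown in the diff above.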
@@ -38,12 +38,16 @@ class BucketStore(BaseStore):
         self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
+        self._dp_parallel_mode = dp_parallel_mode

         self.reset()

     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]

+    def get_dp_parallel_mode(self):
+        return self._dp_parallel_mode
+
     def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
         self._num_elements_in_bucket[reduce_rank] += num_elements

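On the store side, the only new state is the parallel mode the bucket is reduced over. A simplified, self-contained sketch of a bucket store carrying that mode (not the repository's full BaseStore hierarchy; the constructor argument mirrors the dp_parallel_mode used above):

from enum import Enum, auto


class ParallelMode(Enum):
    DATA = auto()
    EXPERT_DATA = auto()


class BucketStore:
    # Simplified stand-in: tracks pending element counts per reduce_rank and
    # remembers which process-group flavour this bucket is reduced over.

    def __init__(self, dp_parallel_mode: ParallelMode):
        self._dp_parallel_mode = dp_parallel_mode
        self._num_elements_in_bucket = {}

    def get_dp_parallel_mode(self) -> ParallelMode:
        return self._dp_parallel_mode

    def add_num_elements_in_bucket(self, num_elements: int, reduce_rank=None):
        current = self._num_elements_in_bucket.get(reduce_rank, 0)
        self._num_elements_in_bucket[reduce_rank] = current + num_elements

    def num_elements_in_bucket(self, reduce_rank=None) -> int:
        return self._num_elements_in_bucket.get(reduce_rank, 0)


# Usage: one store per flavour; the optimizer later asks each store which
# group to pass to reduce_tensor.
moe_bucket = BucketStore(ParallelMode.EXPERT_DATA)
non_moe_bucket = BucketStore(ParallelMode.DATA)
assert moe_bucket.get_dp_parallel_mode() is ParallelMode.EXPERT_DATA
assert non_moe_bucket.get_dp_parallel_mode() is ParallelMode.DATA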