diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 660cc55..985e57f 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -319,6 +319,13 @@ def args_sanity_check(): if "moe_loss_coeff" not in gpc.config.loss: gpc.config.loss._add_item("moe_loss_coeff", 1.0) + # moe not support overlap and zero1.5 for now + if hasattr(gpc.config.model, "num_experts"): + assert ( + not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param + ), "not support overlap and moe at the same time" + assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this" + def launch( config: Union[str, Path, Config, Dict], diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 6894945..b2680ed 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -3,7 +3,6 @@ import math from functools import partial -from itertools import product import torch import torch.distributed as dist @@ -11,7 +10,6 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import is_moe_param from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ -116,16 +114,15 @@ class HybridZeroOptimizer(BaseOptimizer): super().__init__(optim=optimizer) self._cpu_offload = cpu_offload - self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1) - self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1) - self._broadcast_parallel_mode = ParallelMode.ZERO1 + self._zero_local_rank = [] + self._zero_world_size = [] + self._broadcast_parallel_mode = [] # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training self._param_store = ParameterStore(ParallelMode.ZERO1) self._grad_store = GradientStore(ParallelMode.DATA) - self._non_moe_bucket_store = BucketStore(ParallelMode.DATA) - self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA) + self._bucket_store = [] self._bucket_in_progress = [] # fp16 and fp32 params for mixed precision training @@ -163,7 +160,7 @@ class HybridZeroOptimizer(BaseOptimizer): f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_" + f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_" + f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_" - + f"zo-{self._zero_local_rank}.pt" + + f"zo-{gpc.get_local_rank(ParallelMode.ZERO1)}.pt" ) self.params_per_rank_id_dict = [] self._param_bcast_sync_handler = param_bcast_sync_handler @@ -182,10 +179,23 @@ class HybridZeroOptimizer(BaseOptimizer): # add the fp16 params to fp16_param_groups for bookkeeping self._fp16_param_groups[group_id] = group_params + # to find real zero mode. 
if zero is not used, set all param group as ParallelMode.ZERO1 + # if zero is used, expert dp group will use ParallelMode.EXPERT_DATA as the real zero mode + zero_mode = ( + ParallelMode.ZERO1 + if gpc.get_world_size(ParallelMode.ZERO1) == 1 or param_group["dp_mode"] == ParallelMode.DATA + else ParallelMode.EXPERT_DATA + ) + self._zero_local_rank.append(gpc.get_local_rank(zero_mode)) + self._zero_world_size.append(gpc.get_world_size(zero_mode)) + # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name + self._broadcast_parallel_mode.append(zero_mode) + self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"])) + # assign parameters to ranks the params in the list are sorted - params_per_rank, no_params_ranks = self._partition_param_list(param_group) + params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group) self.param_group_no_params_ranks.append(no_params_ranks) - self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks) + self.param_group_has_params.append(self._zero_local_rank[group_id] not in no_params_ranks) # store the mapping between param to rank each param should belong to only one rank. # we can skip the moe param and do not keep them in _param_store to save memory @@ -204,7 +214,7 @@ class HybridZeroOptimizer(BaseOptimizer): param.data = param.data.cpu() # flatten the reordered tensors - for rank in range(self._zero_world_size): + for rank in range(self._zero_world_size[group_id]): # No flat fp16 buffer is allocated if the process has no parameters. if rank not in self.param_group_no_params_ranks[group_id]: tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) @@ -218,7 +228,7 @@ class HybridZeroOptimizer(BaseOptimizer): # No flat fp32 buffer is allocated if the process has no parameters.
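The per-group bookkeeping above replaces the single ZeRO rank/world-size pair with lists indexed by group_id, and each group's zero mode follows its dp_mode tag. A minimal sketch of that selection rule, using stand-in names (Mode, pick_zero_mode) instead of internlm's ParallelMode and gpc:

# --- illustrative sketch (not part of the patch) ---
# Stand-ins for ParallelMode and the gpc world-size query; names here are
# hypothetical and only mirror the selection rule added in __init__.
from enum import Enum

class Mode(Enum):
    ZERO1 = "zero1"
    DATA = "data"
    EXPERT_DATA = "expert_data"

def pick_zero_mode(dp_mode: Mode, zero1_world_size: int) -> Mode:
    # If ZeRO is effectively disabled (world size 1), every group falls back
    # to ZERO1; otherwise expert groups shard over the expert-data group.
    if zero1_world_size == 1 or dp_mode is Mode.DATA:
        return Mode.ZERO1
    return Mode.EXPERT_DATA

assert pick_zero_mode(Mode.DATA, 8) is Mode.ZERO1
assert pick_zero_mode(Mode.EXPERT_DATA, 8) is Mode.EXPERT_DATA
assert pick_zero_mode(Mode.EXPERT_DATA, 1) is Mode.ZERO1  # zero disabled
# --- end sketch ---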
if self.param_group_has_params[group_id]: fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group( - self._zero_local_rank, group_id + self._zero_local_rank[group_id], group_id ) fp32_flat_current_rank = fp16_flat_current_rank.float() device = "cpu" if self._cpu_offload else get_current_device() @@ -263,44 +273,35 @@ class HybridZeroOptimizer(BaseOptimizer): def num_param_groups(self): return len(self._fp16_param_groups) - def _partition_param_list(self, param_group): + def _partition_param_list(self, group_id, param_group): no_params_ranks = [] - params_per_rank = [[] for _ in range(self._zero_world_size)] - numel_per_rank = [0 for _ in range(self._zero_world_size)] - self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)]) + params_per_rank = [[] for _ in range(self._zero_world_size[group_id])] + numel_per_rank = [0 for _ in range(self._zero_world_size[group_id])] + self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size[group_id])]) param_list = param_group["params"] - if self._is_moe_group(param_group): - # for moe group, we do not need to partition the params, just add current - # params to params_per_rank[_zero_local_rank] - params_per_rank[self._zero_local_rank] = list(param_list) - self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None) - no_params_ranks = list(range(self._zero_world_size)) - no_params_ranks.pop(self._zero_local_rank) + sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) + for i, param in enumerate(sorted_params): + global_id = str(i) + for j in range(len(param.size())): + global_id = "_".join([global_id, str(param.size()[j])]) + if self._overlap_sync_param: + rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param) + else: + rank_to_go = numel_per_rank.index(min(numel_per_rank)) + params_per_rank[rank_to_go].append(param) + self.params_per_rank_id_dict[-1][rank_to_go].append(global_id) + numel_per_rank[rank_to_go] += param.numel() - else: - sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - for i, param in enumerate(sorted_params): - global_id = str(i) - for j in range(len(param.size())): - global_id = "_".join([global_id, str(param.size()[j])]) - if self._overlap_sync_param: - rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param) - else: - rank_to_go = numel_per_rank.index(min(numel_per_rank)) - params_per_rank[rank_to_go].append(param) - self.params_per_rank_id_dict[-1][rank_to_go].append(global_id) - numel_per_rank[rank_to_go] += param.numel() + # check whether any rank is not assigned to parameters. + for rank, params in enumerate(params_per_rank): + if len(params) == 0: + no_params_ranks.append(rank) - # check whether any rank is not assigned to parameters. 
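_partition_param_list now applies the same greedy balancing to every group, MoE groups included: parameters are sorted by element count, largest first, and each one goes to the zero rank with the smallest running total unless the broadcast-overlap handler picks the rank. A self-contained sketch of that balancing with plain integers standing in for parameters (function and variable names are illustrative):

# --- illustrative sketch (not part of the patch) ---
# Greedy numel balancing as in _partition_param_list, minus the overlap and
# global-id bookkeeping.
def partition_by_numel(param_sizes, world_size):
    params_per_rank = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size
    # largest-first keeps the per-rank totals close to each other
    for size in sorted(param_sizes, reverse=True):
        rank_to_go = numel_per_rank.index(min(numel_per_rank))
        params_per_rank[rank_to_go].append(size)
        numel_per_rank[rank_to_go] += size
    no_params_ranks = {r for r, p in enumerate(params_per_rank) if not p}
    return params_per_rank, numel_per_rank, no_params_ranks

per_rank, totals, empty = partition_by_numel([4096, 4096, 1024, 512, 512], world_size=2)
print(totals)  # [5120, 5120] -> balanced
print(empty)   # ranks that received nothing get no flat fp16 buffer
# --- end sketch ---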
- for rank, params in enumerate(params_per_rank): - if len(params) == 0: - no_params_ranks.append(rank) - - if gpc.is_rank_for_log(): - logger.info( # pylint: disable=W1203 - f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}" - ) + if gpc.is_rank_for_log(): + logger.info( # pylint: disable=W1203 + f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}" + ) return params_per_rank, set(no_params_ranks) @@ -313,6 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer): def _is_gate_group(self, param_group): return "gate" in param_group.keys() and param_group["gate"] + # TODO check expert dp is correct when enable moe and overlap both def _attach_reduction_hook(self): # we iterate over the fp16 params # on each param, we register a hook to its AccumulateGrad object @@ -346,16 +348,28 @@ class HybridZeroOptimizer(BaseOptimizer): _define_and_attach(param, reduce_rank) + def belongs_to_current_rank(self, param) -> bool: + """ + Check whether a parameter is supposed to be updated by the process of the current rank + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + + :return: True if the parameter should be updated by the current rank. Otherwise false. + :rtype: bool + """ + tensor_rank = self._param_store.get_param_rank(param) + group_id = getattr(param, "group_id") + return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id]) + def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None): param_size = param.numel() # check if the bucket is full # if full, will reduce the grads already in the bucket # after reduction, the bucket will be empty - if is_moe_param(param): - current_bucket = self._moe_bucket_store - else: - current_bucket = self._non_moe_bucket_store + group_id = getattr(param, "group_id") + current_bucket = self._bucket_store[group_id] if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False) @@ -382,6 +396,7 @@ class HybridZeroOptimizer(BaseOptimizer): reduce_rank=reduce_rank, grads=current_bucket.get_grad(reduce_rank=reduce_rank), bucket_size=current_bucket.num_elements_in_bucket(reduce_rank), + group_id=current_bucket.get_param_group_id(), dp_parallel_mode=current_bucket.get_dp_parallel_mode(), ) @@ -402,14 +417,14 @@ class HybridZeroOptimizer(BaseOptimizer): # update the flag self._param_store.set_param_reduction_state(param, True) - if self._param_store.belongs_to_current_rank(param): + if self.belongs_to_current_rank(param): self._param_store.add_reduced_param_for_compute_norm(param, last_bucket) else: self._param_store.add_previous_reduced_param(param) current_bucket.reset_by_rank(reduce_rank) - def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode): + def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, group_id, dp_parallel_mode): grad_buckets_by_dtype = split_half_float_double(grads) next_bucket_list = [] # add parameters into bucket for reduction @@ -418,7 +433,9 @@ class HybridZeroOptimizer(BaseOptimizer): for tensor in tensor_list: param_bucket.add_to_bucket(tensor, allow_oversize=True) if not param_bucket.is_empty(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode) + self._reduce_and_copy( + bucket=param_bucket, reduce_rank=reduce_rank, group_id=group_id, dp_parallel_mode=dp_parallel_mode + ) next_bucket_list.append(param_bucket) # wait for the 
completion of previouce bucket list reduction, and do unflatten_and_copy() @@ -433,7 +450,7 @@ class HybridZeroOptimizer(BaseOptimizer): # after the completion of bucket list reduction, add new buckets into _bucket_in_progress self._bucket_in_progress = next_bucket_list.copy() - def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode): + def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode): # flatten the tensors and do allreduce bucket.flatten() bucket.commu_handle = reduce_tensor( @@ -444,7 +461,7 @@ class HybridZeroOptimizer(BaseOptimizer): ) # update the reduced tensor - if reduce_rank is None or reduce_rank == self._zero_local_rank: + if reduce_rank is None or reduce_rank == self._zero_local_rank[group_id]: bucket.set_unflatten_and_copy_flag(flag=True) def _has_inf_or_nan(self, tensor): @@ -473,8 +490,8 @@ class HybridZeroOptimizer(BaseOptimizer): avg_gradients = self._grad_store._averaged_gradients for group_id in range(self.num_param_groups): # the following operations are performed only on the rank to which parameters are assigned. - if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]: - param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id) + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: + param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank[group_id], group_id) if group_id not in avg_gradients: avg_gradients[group_id] = [] @@ -538,37 +555,11 @@ class HybridZeroOptimizer(BaseOptimizer): parameters=params, last_stage=last_stage, previous_norm=previous_norm, + zero_mode=self._broadcast_parallel_mode[group_id], ) return norm - def _compute_norm_with_moe_group(self, group_id): - params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank) - # we do not get the average grad for moe parameters, so we have to constuct the gradients list here. - grads = [p.grad for p in params] - - if len(params) == 0: - grads = [self.padding_grad] - params = [self.padding_tensor] - - norm = 0 - if self._clip_grad_norm > 0: - norm = compute_norm( - gradients=grads, - parameters=params, - last_stage=True, - is_moe_group=True, - ) - - # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce - # model and zero have been reduced!!! - pg = gpc.get_group(ParallelMode.DATA) - scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - all_groups_norm = scaled_norm_tensor.item() - return all_groups_norm - @llm_timeout(func_name="optim_step") def step(self, closure=None): """Performs a single optimization step. 
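Since each param group now owns its own BucketStore, gradients are staged per group and reduced over that group's data-parallel communicator once a bucket would overflow the reduce bucket size; the leftovers are flushed group by group in step(). A rough sketch of just the threshold/flush logic, with a hypothetical reduce_fn standing in for the real reduction:

# --- illustrative sketch (not part of the patch) ---
# Per-group gradient bucketing: flush when adding a parameter would push the
# bucket past the reduce threshold. `reduce_fn` stands in for the reduce over
# the group's data-parallel process group.
class TinyBucket:
    def __init__(self, group_id, reduce_bucket_size, reduce_fn):
        self.group_id = group_id
        self.reduce_bucket_size = reduce_bucket_size
        self.reduce_fn = reduce_fn
        self.grads, self.numel = [], 0

    def add_grad(self, name, size):
        if self.numel + size > self.reduce_bucket_size:
            self.flush(last_bucket=False)
        self.grads.append(name)
        self.numel += size

    def flush(self, last_bucket=True):
        if self.grads:
            self.reduce_fn(self.group_id, self.grads, last_bucket)
        self.grads, self.numel = [], 0

buckets = {0: TinyBucket(0, 512, print), 1: TinyBucket(1, 512, print)}
buckets[1].add_grad("expert.w1", 400)
buckets[1].add_grad("expert.w2", 300)  # 400 + 300 > 512 -> flush the first grad
for b in buckets.values():             # mirrors the per-group flush in step()
    b.flush()
# --- end sketch ---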
@@ -591,16 +582,13 @@ class HybridZeroOptimizer(BaseOptimizer): self._store_and_try_reduce_grads_by_bucket(param) # we need to reduce the gradients left in the communication bucket - self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True) - self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True) + for group_id in range(self.num_param_groups): + self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True) # compute norm for gradients in the before bucket groups_norms = [] for group_id in range(self.num_param_groups): - if self._is_moe_group(self.optim.param_groups[group_id]): - groups_norms.append(None) - else: - groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) + groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) # clear reduced grads # grads in the last bucket is reduced @@ -616,15 +604,22 @@ class HybridZeroOptimizer(BaseOptimizer): for group_id in range(self.num_param_groups): group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default" group_name = f"{group_id}_{group_name}" + total_norms[group_name] = self._compute_norm_with_stage( + group_id=group_id, + last_bucket=True, + last_stage=True, + previous_norm=groups_norms[group_id], + ) + + # Need to allreduce(avg) the norms across different ranks because moe params will not be synced + # during allreduce if self._is_moe_group(self.optim.param_groups[group_id]): - total_norms[group_name] = self._compute_norm_with_moe_group(group_id=group_id) - else: - total_norms[group_name] = self._compute_norm_with_stage( - group_id=group_id, - last_bucket=True, - last_stage=True, - previous_norm=groups_norms[group_id], - ) + # model and zero have been reduced!!! + pg = gpc.get_group(ParallelMode.EXPERT) + scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=pg) + total_norms[group_name] = scaled_norm_tensor.item() timer("sync_grad").start() self._sync_grad() @@ -746,7 +741,7 @@ class HybridZeroOptimizer(BaseOptimizer): for group_id in range(len(self._fp16_param_groups)): if self.param_group_has_params[group_id]: fp16_param = self._param_store.get_flat_fp16_param_by_rank_group( - rank=self._zero_local_rank, group_id=group_id + rank=self._zero_local_rank[group_id], group_id=group_id ) fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) @@ -766,27 +761,26 @@ class HybridZeroOptimizer(BaseOptimizer): def broadcast_params(self): handles = [] - for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)): - if self._is_moe_group(self.optim.param_groups[group_id]): - continue - # The following operations are performed only on the rank to which parameters are assigned. 
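Because expert gradients are only reduced inside the EXPERT_DATA group, the per-group norm of an MoE group still has to be averaged across the expert-parallel group, which is what the scale-then-all-reduce above does. The arithmetic, with the collective replaced by a plain sum (illustrative only):

# --- illustrative sketch (not part of the patch) ---
# Averaging a grad norm across the expert-parallel group: each rank divides by
# the group size, then an all-reduce(SUM) yields the mean on every rank. The
# patch does the same with dist.all_reduce over gpc.get_group(ParallelMode.EXPERT).
def average_over_group(local_norms):
    world_size = len(local_norms)
    scaled = [n / world_size for n in local_norms]  # per-rank scaling
    reduced = sum(scaled)                           # all_reduce(SUM) equivalent
    return reduced

print(average_over_group([3.0, 5.0, 4.0, 4.0]))     # 4.0 on every rank
# --- end sketch ---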
- if rank in self.param_group_no_params_ranks[group_id]: - continue - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) - # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank - # assert grank == rank, f"{grank} == {rank}" - g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank] - handle = dist.broadcast( - fp16_param, - src=g_rank, - group=gpc.get_group(ParallelMode.ZERO1), - async_op=True, - ) + for group_id in range(self.num_param_groups): + for rank in range(self._zero_world_size[group_id]): + # The following operations are performed only on the rank to which parameters are assigned. + if rank in self.param_group_no_params_ranks[group_id]: + continue + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank + # assert grank == rank, f"{grank} == {rank}" + g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank] + handle = dist.broadcast( + fp16_param, + src=g_rank, + group=gpc.get_group(self._broadcast_parallel_mode[group_id]), + async_op=True, + ) - if self._overlap_sync_param: - self._param_bcast_sync_handler.add_bcast_handle(rank, handle) - else: - handles.append(handle) + if self._overlap_sync_param: + self._param_bcast_sync_handler.add_bcast_handle(rank, handle) + else: + handles.append(handle) for handle in handles: handle.wait() @@ -802,7 +796,7 @@ class HybridZeroOptimizer(BaseOptimizer): # check for overflow for group_id in range(len(self._fp16_param_groups)): # The following operations are performed only on the rank to which parameters are assigned. - if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]: + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): if avg_grad is not None and has_inf_or_nan(avg_grad): self._found_overflow.fill_(1.0) @@ -843,7 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer): flat_fp32_weights = {} for group_id, param in self._fp32_flat_param_groups_of_current_rank.items(): - if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]: + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: assert param.grad is None flat_fp32_weights[group_id] = param states["flat_fp32_weights"] = flat_fp32_weights @@ -863,7 +857,7 @@ class HybridZeroOptimizer(BaseOptimizer): flat_fp32_weights = states["flat_fp32_weights"] assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank) for group_id, param in flat_fp32_weights.items(): - if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]: + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: self_param = self._fp32_flat_param_groups_of_current_rank[group_id] assert ( self_param.shape == param.shape @@ -872,9 +866,9 @@ class HybridZeroOptimizer(BaseOptimizer): # Load the fp16 model weights. 
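In the rewritten broadcast_params loop, the source rank handed to dist.broadcast must be a global rank, so the group-local index is translated through gpc.get_ranks_in_group(...) for whichever parallel mode the group shards over. A tiny illustration with a made-up membership table (the rank numbers below are invented):

# --- illustrative sketch (not part of the patch) ---
# dist.broadcast takes a *global* src rank, so a group-local index has to be
# mapped through the group's member list. In the patch the table comes from
# gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id]).
ranks_in_group = {
    "zero1":       [0, 1, 2, 3, 4, 5, 6, 7],  # non-expert groups
    "expert_data": [0, 4],                    # expert params sharded over fewer ranks
}

def to_global_rank(mode: str, local_rank: int) -> int:
    return ranks_in_group[mode][local_rank]

assert to_global_rank("zero1", 3) == 3
assert to_global_rank("expert_data", 1) == 4  # local rank 1 -> global rank 4
# --- end sketch ---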
for group_id in range(len(self._fp16_param_groups)): - if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]: + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: fp16_param = self._param_store.get_flat_fp16_param_by_rank_group( - rank=self._zero_local_rank, group_id=group_id + rank=self._zero_local_rank[group_id], group_id=group_id ) fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) @@ -891,7 +885,7 @@ def reload_zero_fp32_buff(optimizer): if optimizer.param_group_has_params[group_id]: # flatten fp16 params have already been updated by 'load_model_checkpoint' fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group( - optimizer._zero_local_rank, group_id + optimizer._zero_local_rank[group_id], group_id ) # param_group["params"] is fp32 flatten optimizer states of this zero rank. param_group["params"][0].data.copy_(fp16_flat_current_rank.float()) diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index 7c46b83..33380eb 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -33,18 +33,22 @@ class BucketStore(BaseStore): Bucket Store """ - def __init__(self, dp_parallel_mode): + def __init__(self, group_id, dp_parallel_mode): super().__init__(dp_parallel_mode) self._grads = dict() self._params = dict() self._num_elements_in_bucket = dict() self._dp_parallel_mode = dp_parallel_mode + self._group_id = group_id self.reset() def num_elements_in_bucket(self, reduce_rank: int = None): return self._num_elements_in_bucket[reduce_rank] + def get_param_group_id(self): + return self._group_id + def get_dp_parallel_mode(self): return self._dp_parallel_mode @@ -182,20 +186,6 @@ class ParameterStore(BaseStore): """ return self._fp16_param_to_rank[tensor] - def belongs_to_current_rank(self, tensor) -> bool: - """ - Check whether a parameter is supposed to be updated by the process of the current rank - - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - - :return: True if the parameter should be updated by the current rank. Otherwise false. - :rtype: bool - """ - - tensor_rank = self._fp16_param_to_rank[tensor] - return tensor_rank == self._local_rank - def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: if rank not in self._rank_groupid_to_fp16_param_list: self._rank_groupid_to_fp16_param_list[rank] = dict() diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 9dac6f8..f4816a7 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -209,7 +209,9 @@ def calc_lp(grads, norm_type): return norm -def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False): +def compute_norm( + gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1 +): """Get the norm Arguments: gradients (Iterable[Tensor]): The gradient value. @@ -302,8 +304,7 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no # This is because we use zero1, so we need to use this reduction. # TODO: Check zero group to be a subset of dp group. 
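compute_norm now takes the zero mode as an argument, so the partial norms held by each ZeRO shard are summed over whichever group actually owns the shards (ZERO1 or EXPERT_DATA) before the root is taken. A plain-Python sketch of that assembly for the L2 case, with lists standing in for the shards and the process group:

# --- illustrative sketch (not part of the patch) ---
# With ZeRO, each rank only holds a shard of the gradients, so the L2 norm is
# assembled from per-shard partial sums: all-reduce(SUM) of ||g_shard||^2 over
# the group compute_norm is told to use (zero_mode), then a square root.
def sharded_l2_norm(shards):
    partial_sq = [sum(g * g for g in shard) for shard in shards]  # one value per rank
    total_sq = sum(partial_sq)                                    # all_reduce(SUM) equivalent
    return total_sq ** 0.5

full = [0.5, -1.0, 2.0, 0.25]
assert abs(sharded_l2_norm([full[:2], full[2:]]) - sum(g * g for g in full) ** 0.5) < 1e-12
# --- end sketch ---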
- if not is_moe_group: - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1)) + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode)) if torch.is_tensor(total_norm): total_norm = total_norm.item() diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 0e249fe..9096a2a 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -2,6 +2,7 @@ from typing import Dict, Tuple import torch +from internlm.core.context.parallel_context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param @@ -37,13 +38,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) # create new groups for fp32, norm, moe gate and moe expert new_groups = {} - new_groups["fp32"] = {"name": "fp32", "params": []} + new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA} if gpc.config.model.get("num_experts", 0) > 1: # norm and gate are special group to force sync (when enable MoE). for key in ["gate", "norm"]: - new_groups[key] = {"name": key, key: True, "params": []} + new_groups[key] = {"name": key, key: True, "params": [], "dp_mode": ParallelMode.DATA} for key in gpc.expert_parallel_group_names: - new_groups[key] = {"name": key, "moe": True, "params": []} + new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA} for pgroup in param_groups: # copy attribute from origin group, we assume the input param_groups only @@ -72,6 +73,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) # bf16 param group, which is the first group in the param groups pgroup["params"] = origin_params + pgroup["dp_mode"] = ParallelMode.DATA # param groups may contain empty groups, such as fp32 param_groups.extend(new_groups.values()) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 566bb0f..00e7436 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path): zero_rank = gpc.get_local_rank(ParallelMode.ZERO1) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + zero_size = gpc.get_world_size(ParallelMode.ZERO1) tp_size = gpc.get_world_size(ParallelMode.TENSOR) pp_size = gpc.get_world_size(ParallelMode.PIPELINE) fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" states = optim.state_dict() if isinstance(optim, HybridZeroOptimizer): - if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size: + if gpc.get_global_rank() < zero_size * tp_size * pp_size: llm_save(os.path.join(state_path, fp), states) if "zero_devide_optim_plan" in states: params_per_rank_id_dict = states.pop("zero_devide_optim_plan") diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 3a10227..9efef10 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -5,7 +5,6 @@ import torch.distributed as dist from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import is_moe_param def is_model_parallel_parameter(p): @@ -23,7 +22,7 @@ def sync_model_param(model): gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1 ) for param in model.parameters(): - if sync_moe_param and 
is_moe_param(param): + if sync_moe_param and getattr(param, "is_expert", False): ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA) dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA)) else: diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 956880b..0804455 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -86,13 +86,13 @@ ckpt_config_list = [ def overwrite_optim_state(optim, set_value): if isinstance(optim, HybridZeroOptimizer): for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items(): - if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]: # p.copy_(torch.full_like(p, set_value, dtype=p.dtype)) p.data.fill_(set_value) for group_id in range(len(optim._fp16_param_groups)): - if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]: fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group( - rank=optim._zero_local_rank, group_id=group_id + rank=optim._zero_local_rank[group_id], group_id=group_id ) fp16_p.fill_(set_value) else: @@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2): fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2): re &= group_id_1 == group_id_2 - if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]: + if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]: re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2]) else: for group1, group2 in zip(optim1.param_groups, optim2.param_groups): @@ -122,12 +122,12 @@ def compare_optim_value(optim, value): re = True if isinstance(optim, HybridZeroOptimizer): for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items(): - if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]: re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype)) for group_id in range(len(optim._fp16_param_groups)): - if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]: fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group( - rank=optim._zero_local_rank, group_id=group_id + rank=optim._zero_local_rank[group_id], group_id=group_id ) re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype)) else:
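For reference, the dp_mode tags added in internlm/train/utils.py are what later drive the zero-mode choice inside the optimizer. A simplified sketch of how param groups might be labelled, keeping only the default/expert split and using string tags plus an is_expert flag in place of ParallelMode and the real gate/norm/fp32 handling:

# --- illustrative sketch (not part of the patch) ---
# Building optimizer param groups tagged with a dp_mode, mirroring the
# "dp_mode": ParallelMode.DATA / ParallelMode.EXPERT_DATA labels added in
# internlm/train/utils.py. Params are plain objects carrying an `is_expert`
# flag, matching the getattr(param, "is_expert", False) check in sync_model_param.
from types import SimpleNamespace

def build_groups(params):
    groups = {
        "default": {"name": "default", "params": [], "dp_mode": "DATA"},
        "expert":  {"name": "expert", "moe": True, "params": [], "dp_mode": "EXPERT_DATA"},
    }
    for p in params:
        key = "expert" if getattr(p, "is_expert", False) else "default"
        groups[key]["params"].append(p)
    return list(groups.values())

params = [SimpleNamespace(name="w_qkv"), SimpleNamespace(name="expert.w1", is_expert=True)]
for g in build_groups(params):
    print(g["name"], g["dp_mode"], [p.name for p in g["params"]])
# --- end sketch ---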