mirror of https://github.com/InternLM/InternLM

fix moe bugs in zero optimizer

commit 9ee57e6c8a
parent 3bfaad895a
@@ -166,10 +166,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
-            if "moe" in param_group.keys() and param_group["moe"]:
-                print("true", flush=True)
-                continue
-
             group_params = param_group["params"]

             # add the fp16 params to fp16_param_groups for bookkeeping
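For context: the "moe" flag tested above (whose early-skip path, leftover debug print included, is removed by this commit) is just an extra key on the optimizer's param-group dicts. A minimal sketch of how such groups might be built; the `is_expert` marker and the helper below are assumptions for illustration, not InternLM's actual setup code:

import torch


def is_moe_param(param: torch.nn.Parameter) -> bool:
    # Hypothetical marker: expert params are assumed to carry `is_expert`.
    return getattr(param, "is_expert", False)


def split_params_into_groups(model: torch.nn.Module) -> list:
    # Separate expert params from dense ones and tag each group,
    # mirroring the `param_group["moe"]` check in the optimizer.
    dense_params, moe_params = [], []
    for param in model.parameters():
        (moe_params if is_moe_param(param) else dense_params).append(param)
    return [
        {"params": dense_params, "moe": False},
        {"params": moe_params, "moe": True},
    ]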
@@ -180,7 +176,10 @@ class HybridZeroOptimizer(BaseOptimizer):
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)

-            # store the mapping between param to rank each param should belong to only one rank
+            # store the mapping between param and rank; each param should belong to only one rank.
+            # we could skip the moe params and not keep them in _param_store to save memory,
+            # but that would mean dealing with moe params in a different way, increasing
+            # complexity and reducing code readability.
             for rank, params in enumerate(params_per_rank):
                 # check whether any rank is not assigned params.
                 if len(params) != 0:
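The new comment records a design choice: moe params could be dropped from `_param_store` to save memory, at the price of special-casing them everywhere else. A toy sketch of the param-to-rank bookkeeping being kept uniform instead; class and method names are simplified assumptions, not the real parameter-store API:

class SimpleParamStore:
    # Records which zero rank owns each param so later lookups
    # (moe params included) need no special cases.
    def __init__(self):
        self._param_to_rank = {}

    def set_param_to_rank(self, param, rank) -> None:
        self._param_to_rank[param] = rank

    def param_to_rank(self, param) -> int:
        return self._param_to_rank[param]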
@@ -267,26 +266,34 @@ class HybridZeroOptimizer(BaseOptimizer):
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])

-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for i, param in enumerate(sorted_params):
-            global_id = str(i)
-            for j in range(len(param.size())):
-                global_id = "_".join([global_id, str(param.size()[j])])
+        if "moe" in param_list.keys() and param_list["moe"]:
+            # just add current params to params_per_rank[_zero_local_rank]
+            params_per_rank[self._zero_local_rank] = list(param_list["params"])
+            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
+            no_params_ranks = list(range(self._zero_world_size))
+            no_params_ranks.pop(self._zero_local_rank)
+        else:
+            sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+            for i, param in enumerate(sorted_params):
+                global_id = str(i)
+                for j in range(len(param.size())):
+                    global_id = "_".join([global_id, str(param.size()[j])])

-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
-            params_per_rank[rank_to_go].append(param)
-            self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
-            numel_per_rank[rank_to_go] += param.numel()
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
+                params_per_rank[rank_to_go].append(param)
+                self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
+                numel_per_rank[rank_to_go] += param.numel()

-        # check whether any rank is not assigned to parameters.
-        for rank, params in enumerate(params_per_rank):
-            if len(params) == 0:
-                no_params_ranks.append(rank)
+            # check whether any rank is not assigned to parameters.
+            for rank, params in enumerate(params_per_rank):
+                if len(params) == 0:
+                    no_params_ranks.append(rank)

-        if gpc.is_rank_for_log():
-            logger.info(  # pylint: disable=W1203
-                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
-            )
+            if gpc.is_rank_for_log():
+                logger.info(  # pylint: disable=W1203
+                    f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+                )

         return params_per_rank, set(no_params_ranks)
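The rewritten `_partition_param_list` now has two cases. For a moe group, the expert params are already local to this rank (expert parallelism shards them), so the whole list is pinned to `_zero_local_rank` and every other rank is recorded in `no_params_ranks`. For a dense group, params are spread greedily by element count. A self-contained sketch of that greedy strategy, with plain integers standing in for `param.numel()`:

def partition_by_numel(param_sizes, world_size):
    # Largest tensors first; each goes to the rank currently holding
    # the fewest elements, mirroring the rank_to_go logic above.
    params_per_rank = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size
    for size in sorted(param_sizes, reverse=True):
        rank_to_go = numel_per_rank.index(min(numel_per_rank))
        params_per_rank[rank_to_go].append(size)
        numel_per_rank[rank_to_go] += size
    return params_per_rank, numel_per_rank


print(partition_by_numel([100, 90, 40, 30, 20, 10], 4))
# -> ([[100], [90], [40, 10], [30, 20]], [100, 90, 50, 50])
print(partition_by_numel([100, 90], 4))
# -> ([[100], [90], [], []], [100, 90, 0, 0])
# Ranks 2 and 3 get nothing here; those are exactly the ranks the
# loop in the else branch collects into no_params_ranks.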
@@ -296,6 +303,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
+                # we should not reduce the param in moe
                 if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None
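The `is_moe_param` guard keeps expert params out of ZeRO's data-parallel gradient reduction; expert gradients are only meaningful within an expert data-parallel group, not across the full ZeRO group. A hedged sketch of that split; the process-group arguments and `is_expert` marker below are assumptions, not InternLM's API:

import torch.distributed as dist


def reduce_grads(params, dp_group, expert_dp_group):
    # Dense grads are averaged over the full data-parallel group;
    # expert grads only over the (assumed) expert data-parallel group.
    for param in params:
        if param.grad is None:
            continue
        group = expert_dp_group if getattr(param, "is_expert", False) else dp_group
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
        param.grad.div_(dist.get_world_size(group=group))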
@@ -496,6 +504,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if not self._overlap_communication:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
+                    # we should not reduce the param in moe
                     if param.grad is not None and not is_moe_param(param):
                         self._store_and_try_reduce_grads_by_bucket(param)
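Note that both reduction paths are guarded: the loop in the previous hunk skips moe params, and the non-overlapped path above likewise keeps them out of `_store_and_try_reduce_grads_by_bucket`, so expert gradients bypass bucket reduction consistently whether or not communication overlap is enabled.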