mirror of https://github.com/InternLM/InternLM
				
				
				
			fix more moe bugs in zero optimizer
							parent
							
								
									9ee57e6c8a
								
							
						
					
					
						commit
						48ed81d2e6
					
				| 
						 | 
				
			
			@ -172,7 +172,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 | 
			
		|||
            self._fp16_param_groups[group_id] = group_params
 | 
			
		||||
 | 
			
		||||
            # assign parameters to ranks; the params in the list are sorted
 | 
			
		||||
            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
 | 
			
		||||
            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
 | 
			
		||||
            self.param_group_no_params_ranks.append(no_params_ranks)
 | 
			
		||||
            self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -260,15 +260,22 @@ class HybridZeroOptimizer(BaseOptimizer):
 | 
			
		|||
    def num_param_groups(self):
        """Return how many entries are stored in ``self._fp16_param_groups``."""
        fp16_groups = self._fp16_param_groups
        return len(fp16_groups)
 | 
			
		||||
 | 
			
		||||
    def _partition_param_list(self, param_list):
 | 
			
		||||
    def _get_real_dp_process_group(self, param_groups):
        """Pick the parallel mode to use for the given parameter group.

        Args:
            param_groups: mapping that is checked for a ``"moe"`` key.

        Returns:
            ``ParallelMode.EXPERT_DATA`` when ``param_groups`` has a ``"moe"``
            key whose value is truthy, otherwise ``ParallelMode.DATA``.
        """
        # A missing "moe" key and a falsy "moe" value are treated the same way.
        is_moe_group = "moe" in param_groups and param_groups["moe"]
        if not is_moe_group:
            return ParallelMode.DATA
        return ParallelMode.EXPERT_DATA
 | 
			
		||||
 | 
			
		||||
    def _partition_param_list(self, param_group):
 | 
			
		||||
        no_params_ranks = []
 | 
			
		||||
        params_per_rank = [[] for _ in range(self._zero_world_size)]
 | 
			
		||||
        numel_per_rank = [0 for _ in range(self._zero_world_size)]
 | 
			
		||||
        self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
 | 
			
		||||
        param_list = param_group["params"]
 | 
			
		||||
 | 
			
		||||
        if "moe" in param_list.keys() and param_list["moe"]:
 | 
			
		||||
        if "moe" in param_group.keys() and param_group["moe"]:
 | 
			
		||||
            # just add current params to params_per_rank[_zero_local_rank]
 | 
			
		||||
            params_per_rank[self._zero_local_rank] = list(param_list["params"])
 | 
			
		||||
            params_per_rank[self._zero_local_rank] = list(param_list)
 | 
			
		||||
            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
 | 
			
		||||
            no_params_ranks = list(range(self._zero_world_size))
 | 
			
		||||
            no_params_ranks.pop(self._zero_world_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue