"""
 | 
						|
The file has been adapted from the following files:
 | 
						|
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 | 
						|
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 | 
						|
 We retain the following license from the original files:
 | 
						|
"""

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Union, cast

import torch
from torch.nn import Module, ModuleList


class Experts(torch.nn.Module):
    """
    Local experts: the expert sub-networks hosted on the current rank of an
    MoE layer.
    """

    def __init__(self, experts: Union[Module, ModuleList], num_local_experts: int = 1):
        super().__init__()

        # TODO: We cannot deepcopy FeedForward since it contains a process_group in submodules
        # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])

        if isinstance(experts, ModuleList):
            self.experts = cast(ModuleList, experts)
        else:
            self.experts = ModuleList([experts])
        self.num_local_experts = num_local_experts

        # Expert parameters live only on this rank, so flag them to be
        # skipped when gradients are all-reduced across data-parallel ranks.
        # TODO: revisit allreduce for moe.gate...
        for expert in self.experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for _, param in expert.named_parameters():
                param.all_reduce = False

    def forward(self, inputs):
        # Hand each local expert its slice of the dispatched tokens:
        # one chunk per expert along dim 1.
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            out = expert(chunk)
            if isinstance(out, tuple):
                out = out[0]  # Ignore the bias term for now
            expert_outputs += [out]

        # Reassemble the per-expert outputs in the dispatch layout.
        expert_output = torch.cat(expert_outputs, dim=1)
        return expert_output
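

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the adapted
# DeepSpeed file). It assumes each expert is a plain feed-forward module and
# that the dispatched input is laid out as
# (ep_size, num_local_experts, capacity, d_model), so chunking along dim=1
# hands one slice to each local expert. `grads_to_allreduce` is a
# hypothetical helper showing how a training loop might honour the
# `all_reduce = False` flag set in `Experts.__init__`.
# ---------------------------------------------------------------------------
def grads_to_allreduce(module: Module):
    """Collect gradients still subject to data-parallel all-reduce,
    skipping parameters tagged with ``all_reduce = False``."""
    return [
        param.grad
        for param in module.parameters()
        if param.grad is not None and getattr(param, "all_reduce", True)
    ]


if __name__ == "__main__":
    d_model, capacity = 8, 4
    experts = Experts(
        ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(2)]),
        num_local_experts=2,
    )

    # One chunk per local expert: (ep_size=1, num_local_experts=2, capacity, d_model).
    dispatched = torch.randn(1, 2, capacity, d_model)
    out = experts(dispatched)
    assert out.shape == dispatched.shape

    out.sum().backward()
    # All expert parameters were tagged with all_reduce = False above,
    # so nothing is left to reduce here.
    assert grads_to_allreduce(experts) == []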