mirror of https://github.com/InternLM/InternLM
48 lines
1.7 KiB
Python
48 lines
1.7 KiB
Python
"""
|
|
The file has been adapted from the following files:
|
|
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
|
|
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
|
|
We retain the following license from the original files:
|
|
"""
|
|
from typing import Union, cast
|
|
|
|
import torch
|
|
from torch.nn import Module, ModuleList
|
|
|
|
|
|
class Experts(torch.nn.Module):
    """
    Local Experts.

    Container for the expert modules hosted locally. ``forward`` splits its
    input along dim 1 into ``num_local_experts`` chunks and feeds chunk ``i``
    to expert ``i``, then concatenates the results back along dim 1.

    Args:
        experts: A single expert module or a ``ModuleList`` of experts. A bare
            module is wrapped in a one-element ``ModuleList``. It is NOT
            replicated (see the deepcopy TODO below), so a bare module only
            works with ``num_local_experts=1``.
        num_local_experts: Number of chunks the input is split into. Must
            equal the number of experts held by this module.
        expert_group_name: Stored on every expert parameter as
            ``param.group_name``.

    Raises:
        ValueError: If ``num_local_experts`` differs from the number of
            experts provided.
    """

    def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
        super().__init__()

        # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
        # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
        self.experts = experts if isinstance(experts, ModuleList) else ModuleList([experts])

        # Experts are not replicated here. A count mismatch would make
        # ``forward``'s zip() silently drop input chunks or leave experts
        # unused, so fail fast at construction time instead.
        if len(self.experts) != num_local_experts:
            raise ValueError(
                f"num_local_experts ({num_local_experts}) must equal the number of "
                f"experts provided ({len(self.experts)})"
            )
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        # Tag every expert parameter. NOTE(review): these attributes are
        # presumably read by optimizer / gradient-sync code elsewhere to
        # separate expert params from dense ones; confirm against consumers.
        for expert in self.experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for param in expert.parameters():
                param.is_expert = True
                param.group_name = expert_group_name

    def forward(self, inputs):
        """Run each local expert on its slice of ``inputs``.

        Args:
            inputs: Tensor split along dim 1 into ``num_local_experts``
                chunks. It is presumably shaped
                (ep_size, num_local_experts, capacity, d_model) by the MoE
                dispatcher; TODO confirm with the caller.

        Returns:
            The per-expert outputs concatenated along dim 1. When an expert
            returns a tuple, only its first element is used.

        Raises:
            ValueError: If ``torch.chunk`` yields a different number of
                chunks than there are experts (e.g. dim 1 is too small).
        """
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        # torch.chunk may return fewer chunks than requested. zip() would then
        # silently skip the trailing experts and mis-route tokens.
        if len(chunks) != len(self.experts):
            raise ValueError(
                f"inputs with size {inputs.shape[1]} along dim 1 split into {len(chunks)} "
                f"chunks, but there are {len(self.experts)} experts"
            )

        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            out = expert(chunk)
            if isinstance(out, tuple):
                out = out[0]  # Ignore the bias term for now
            expert_outputs.append(out)

        return torch.cat(expert_outputs, dim=1)
|