mirror of https://github.com/hpcaitech/ColossalAI
[moe] fix moe bugs (#1633)
parent 702dbc5288
commit a088022efc
@@ -1,8 +1,9 @@
 from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
+from .layers import MoeLayer, MoeModule
+from .routers import MoeRouter, Top1Router, Top2Router
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter'
 ]
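The routers now live in a dedicated `routers` module and are re-exported from the package root. A minimal sketch of how downstream code could import and construct one after this change (constructor defaults are taken from the diff below; the surrounding training setup is assumed, not shown):

```python
# The moe package root keeps re-exporting the router classes, so user code
# does not need to know about the layers/routers split.
from colossalai.nn.layer.moe import MoeRouter, Top1Router, Top2Router

router = Top2Router(capacity_factor_train=1.25, capacity_factor_eval=2.0, min_capacity=4)
assert isinstance(router, MoeRouter)  # both concrete routers derive from the new base class
```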
@@ -1,231 +1,17 @@
 import functools
 import math

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts, Experts
-from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator, autocast_softmax
+from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \
+    ReduceScatter, MoeDispatch, MoeCombine
+from colossalai.nn.layer.moe.experts import MoeExperts, Experts
+from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
 from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
-from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup

-class Top1Router(nn.Module):
-    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
-    for routing usage. More detailed function can be found in the paper about Switch Transformer
-    of Google.
-
-    Args:
-        capacity_factor_train (float, optional): Capacity factor in routing of training.
-        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
-        min_capacity (int, optional): The minimum number of the capacity of each expert.
-        select_policy (str, optional): The policy about tokens selection.
-        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
-        drop_tks (bool, optional): Whether to drop tokens in evaluation
-    """
-
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Callable = None,
-                 drop_tks: bool = True):
-        super().__init__()
-        self.capacity_factor_train = capacity_factor_train
-        self.capacity_factor_eval = capacity_factor_eval
-        self.min_capacity = min_capacity
-        self.select_policy = select_policy
-        self.noisy_func = noisy_func
-        self.drop_tks = drop_tks
-
-        assert select_policy in {"first", "random"}
-        if select_policy == "random":
-            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
-                                                               high=torch.tensor(1.0,
-                                                                                 device=get_current_device())).rsample
-
-    def get_capacity(
-        self,
-        logits_shape,
-    ):
-        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
-        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
-        capacity += capacity % 2
-        capacity = max(capacity, self.min_capacity)
-        assert capacity > 0
-        return capacity
-
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
-
-        if self.noisy_func is not None and self.training:
-            inputs = self.noisy_func(inputs)
-
-        logits = autocast_softmax(inputs, dim=-1)
-        num_experts = logits.size(-1)
-        capacity = self.get_capacity(logits.shape)
-
-        top1_idx = torch.argmax(inputs, dim=-1)
-        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
-
-        if self.training:
-            me = torch.mean(logits, dim=0)
-            ce = torch.mean(mask.float(), dim=0)
-            l_aux = num_experts * torch.sum(me * ce)
-            MOE_CONTEXT.add_loss(l_aux)
-        elif not self.drop_tks:
-            max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
-            capacity = max_num.item()
-        else:
-            pass
-
-        if self.select_policy == "random":
-            rand_mask = mask * self.uniform(mask.shape)
-            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
-            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
-            ranks = moe_cumsum(mask)
-        elif self.select_policy == "first":
-            ranks = moe_cumsum(mask)
-            mask = mask * torch.lt(ranks, capacity)
-        else:
-            raise NotImplementedError("Not support such select policy yet.")
-
-        ranks = torch.sum(mask * ranks, dim=-1)
-
-        if use_kernel:
-            mask = torch.sum(mask, dim=-1)
-            mask = torch.stack([mask], dim=0).to(torch.int32)
-            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
-            return logits, mask, dest_idx, num_experts * capacity
-        else:
-            ranks = F.one_hot(ranks, num_classes=capacity)
-            weight = mask * logits.type_as(inputs)
-            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
-            sec_mask = combine_weights.bool()
-            return combine_weights, sec_mask
-
-
-class Top2Router(nn.Module):
-    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
-    for routing usage. More detailed function can be found in the paper about ViT-MoE.
-
-    Args:
-        capacity_factor_train (float, optional): Capacity factor in routing of training.
-        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
-        min_capacity (int, optional): The minimum number of the capacity of each expert
-        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
-        drop_tks (bool, optional): Whether to drop tokens in evaluation.
-    """
-
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Callable = None,
-                 drop_tks: bool = True):
-        super().__init__()
-        self.capacity_factor_train = capacity_factor_train
-        self.capacity_factor_eval = capacity_factor_eval
-        self.min_capacity = min_capacity
-        self.noisy_func = noisy_func
-        self.drop_tks = drop_tks
-
-    def get_capacity(
-        self,
-        logits_shape,
-    ):
-        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
-        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
-        capacity += capacity % 2
-        capacity = max(capacity, self.min_capacity)
-        assert capacity > 0
-        return capacity
-
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
-        # inputs: [s, h]
-        if self.noisy_func is not None and self.training:
-            inputs = self.noisy_func(inputs)
-
-        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
-        num_experts = logits.size(-1)
-        capacity = self.get_capacity(logits.shape)
-
-        top1_idx = torch.argmax(logits, dim=-1)
-        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
-        logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
-        top2_idx = torch.argmax(logits_except1, dim=-1)
-        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
-
-        cmask = (mask1 + mask2)    # loss: [s, e]
-        if self.training:
-            me = torch.mean(logits, dim=0)
-            ce = torch.mean(cmask.float(), dim=0)
-            l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-            MOE_CONTEXT.add_loss(l_aux)
-        elif not self.drop_tks:
-            max_num = torch.max(torch.sum(cmask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
-            capacity = max_num.item()
-        else:
-            pass
-
-        rank1 = moe_cumsum(mask1)    # rank1: [s, e]
-        rank2 = moe_cumsum(mask2)
-        rank2 += torch.sum(mask1, dim=-2, keepdim=True)
-
-        mask1 *= torch.lt(rank1, capacity)
-        mask2 *= torch.lt(rank2, capacity)
-
-        rank1 = torch.sum(mask1 * rank1, dim=-1)
-        rank2 = torch.sum(mask2 * rank2, dim=-1)
-
-        if use_kernel:
-            mask1 = torch.sum(mask1, dim=-1)
-            mask2 = torch.sum(mask2, dim=-1)
-
-            mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
-            dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
-
-            return logits, mask, dest_idx, num_experts * capacity
-        else:
-            weight1 = mask1 * logits.type_as(inputs)
-            weight2 = mask2 * logits.type_as(inputs)
-            rank1_sc = F.one_hot(rank1, num_classes=capacity)
-            rank2_sc = F.one_hot(rank2, num_classes=capacity)
-
-            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
-            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
-            cb_weight = cb_weight1 + cb_weight2
-            sec_mask = cb_weight.bool()
-
-            return cb_weight, sec_mask
-
-
-class FP32LinearGate(nn.Module):
-    """Gate module used in MOE layer. Just a linear function without bias.
-    But it should be kept as fp32 forever.
-
-    Args:
-        d_model (int): Hidden dimension of training model
-        num_experts (int): The number of experts
-
-    Attributes:
-        weight (ForceFP32Parameter): The weight of linear gate
-    """
-
-    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
-        super().__init__()
-        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
-        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
-
-    def forward(self, x: torch.Tensor):
-        return F.linear(x, self.weight)
+from typing import Optional, Type, Tuple


 @no_shard_zero_decrator(is_replicated=True)
@@ -238,17 +24,17 @@ class MoeLayer(nn.Module):
     Args:
         dim_model (int): Dimension of model.
         num_experts (int): The number of experts.
-        router (:class:`torch.nn.Module`): Instance of router used in routing.
-        experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
+        router (MoeRouter): Instance of router used in routing.
+        experts (MoeExperts): Instance of experts generated by Expert.
     """

-    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
+    def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
         self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
-        self.router = router
-        self.experts = experts
+        self.router: MoeRouter = router
+        self.experts: MoeExperts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
         self.ep_group = experts.dist_info.ep_group
         self.ep_size = experts.dist_info.ep_size
@@ -271,7 +57,7 @@ class MoeLayer(nn.Module):
             expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> Tuple:
         # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
@@ -309,7 +95,8 @@ class MoeLayer(nn.Module):
             ans = torch.matmul(combine_weights, expert_output)

         ans = ans.reshape(inputs.shape)
-        return ans
+        l_aux = self.router.pop_routing_loss()
+        return ans, l_aux


 class MoeModule(nn.Module):
@@ -403,7 +190,7 @@ class MoeModule(nn.Module):
                                  experts=self.experts)

     def forward(self, inputs: torch.Tensor):
-        moe_output = self.moe_layer(inputs)
+        moe_output, l_aux = self.moe_layer(inputs)

         if self.use_residual:
             residual_output = self.residual_module(inputs)
@@ -413,4 +200,4 @@ class MoeModule(nn.Module):
         else:
             output = moe_output

-        return output
+        return output, l_aux
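Both MoeLayer.forward and MoeModule.forward now return the auxiliary routing loss alongside the output instead of pushing it into the global context themselves. A minimal sketch of the new calling convention, assuming MOE_CONTEXT.setup() has already been called in a launched distributed run; the toy sizes mirror the test model further down in the diff:

```python
import torch
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import MoeModule

# Hypothetical toy sizes; extra keyword arguments are forwarded to expert_cls.
moe = MoeModule(dim_model=16, num_experts=8, use_residual=True,
                expert_cls=torch.nn.Linear, in_features=16, out_features=16)

x = torch.randn(4, 16)
out, l_aux = moe(x)          # the module hands the routing loss back to the caller
MOE_CONTEXT.add_loss(l_aux)  # the caller decides where the loss is accumulated
```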
@@ -0,0 +1,226 @@
+import math
+from abc import ABC
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from colossalai.utils import get_current_device
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer.moe._operation import moe_cumsum
+from typing import Callable, Optional
+from torch.distributed import ProcessGroup
+
+
+class MoeRouter(nn.Module, ABC):
+    """Base class for all MoE routers.
+    Args:
+        k_value (int): The value of top_k.
+        capacity_factor_train (float): Capacity factor in routing of training.
+        capacity_factor_eval (float): Capacity factor in routing of evaluation.
+        min_capacity (int): The minimum number of the capacity of each expert.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation
+    """
+
+    def __init__(self,
+                 k_value: int,
+                 capacity_factor_train: float,
+                 capacity_factor_eval: float,
+                 min_capacity: int,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__()
+        self.k_value = k_value
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
+        self.min_capacity = min_capacity
+        self.noisy_func = noisy_func
+        self.drop_tks = drop_tks
+        self._routing_loss = None
+
+    def get_capacity(self, logits_shape):
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity += capacity % 2
+        capacity = max(capacity, self.min_capacity)
+        assert capacity > 0
+        return capacity
+
+    def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
+        assert self._routing_loss is None
+        self._routing_loss = aux_loss
+
+    def pop_routing_loss(self) -> torch.Tensor:
+        assert self._routing_loss is not None
+        reservation = self._routing_loss
+        self._routing_loss = None
+        return reservation
+
+
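For concreteness, a worked example of the shared capacity formula with hypothetical numbers (top-2 routing, 1024 tokens, 8 experts, training-time factor 1.25):

```python
import math

k_value, capacity_factor, num_tokens, num_experts, min_capacity = 2, 1.25, 1024, 8, 4
capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)  # 320
capacity += capacity % 2                # keep the slot count even, still 320
capacity = max(capacity, min_capacity)  # each expert accepts at most 320 tokens per batch
```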
+class Top1Router(MoeRouter):
+    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. More detailed function can be found in the paper about Switch Transformer
+    of Google.
+    Args:
+        capacity_factor_train (float, optional): Capacity factor in routing of training.
+        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+        min_capacity (int, optional): The minimum number of the capacity of each expert.
+        select_policy (str, optional): The policy about tokens selection.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation
+    """
+
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 select_policy: str = "first",
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__(k_value=1,
+                         capacity_factor_train=capacity_factor_train,
+                         capacity_factor_eval=capacity_factor_eval,
+                         min_capacity=min_capacity,
+                         noisy_func=noisy_func,
+                         drop_tks=drop_tks)
+        self.select_policy = select_policy
+        assert select_policy in {"first", "random"}
+        if select_policy == "random":
+            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
+                                                               high=torch.tensor(1.0,
+                                                                                 device=get_current_device())).rsample
+
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
+
+        if self.noisy_func is not None and self.training:
+            inputs = self.noisy_func(inputs)
+
+        assert inputs.dtype == torch.float
+        logits = F.softmax(inputs, dim=-1)
+        num_experts = logits.size(-1)
+        capacity = self.get_capacity(logits.shape)
+
+        top1_idx = torch.argmax(inputs, dim=-1)
+        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+
+        # calculate the auxiliary loss
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(mask.float(), dim=0)
+        l_aux = num_experts * torch.sum(me * ce)
+        self.set_routing_loss(l_aux)
+
+        if not self.training and not self.drop_tks:
+            max_num = torch.max(torch.sum(mask, dim=0))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
+            capacity = max_num.item()
+
+        if self.select_policy == "random":
+            rand_mask = mask * self.uniform(mask.shape)
+            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
+            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
+            ranks = moe_cumsum(mask)
+        elif self.select_policy == "first":
+            ranks = moe_cumsum(mask)
+            mask = mask * torch.lt(ranks, capacity)
+        else:
+            raise NotImplementedError("Not support such select policy yet.")
+
+        ranks = torch.sum(mask * ranks, dim=-1)
+
+        if use_kernel:
+            mask = torch.sum(mask, dim=-1)
+            mask = torch.stack([mask], dim=0).to(torch.int32)
+            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
+            return logits, mask, dest_idx, num_experts * capacity
+        else:
+            ranks = F.one_hot(ranks, num_classes=capacity)
+            weight = mask * logits.type_as(inputs)
+            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
+            sec_mask = combine_weights.bool()
+            return combine_weights, sec_mask
+
+
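A hedged sketch of exercising Top1Router in isolation on its dense (non-kernel) path, assuming a working ColossalAI installation; the token and expert counts are made up:

```python
import torch
from colossalai.utils import get_current_device
from colossalai.nn.layer.moe import Top1Router

router = Top1Router()   # defaults: capacity_factor_train=1.25, select_policy="first"
router.train()

gate_logits = torch.randn(16, 4, device=get_current_device())  # [tokens, experts], fp32
combine_weights, sec_mask = router(gate_logits)  # both shaped [tokens, experts, capacity]
l_aux = router.pop_routing_loss()                # popping twice would trip the assert
```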
+class Top2Router(MoeRouter):
+    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. More detailed function can be found in the paper about ViT-MoE.
+    Args:
+        capacity_factor_train (float, optional): Capacity factor in routing of training.
+        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+        min_capacity (int, optional): The minimum number of the capacity of each expert
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation.
+    """
+
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__(k_value=2,
+                         capacity_factor_train=capacity_factor_train,
+                         capacity_factor_eval=capacity_factor_eval,
+                         min_capacity=min_capacity,
+                         noisy_func=noisy_func,
+                         drop_tks=drop_tks)
+
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
+        # inputs: [s, h]
+        if self.noisy_func is not None and self.training:
+            inputs = self.noisy_func(inputs)
+
+        assert inputs.dtype == torch.float
+        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
+        num_experts = logits.size(-1)
+        capacity = self.get_capacity(logits.shape)
+
+        top1_idx = torch.argmax(logits, dim=-1)
+        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+        logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
+        top2_idx = torch.argmax(logits_except1, dim=-1)
+        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
+
+        cmask = (mask1 + mask2)    # loss: [s, e]
+
+        # calculate the auxiliary loss
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(cmask.float(), dim=0)
+        l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
+        self.set_routing_loss(l_aux)
+
+        if not self.training and not self.drop_tks:
+            max_num = torch.max(torch.sum(cmask, dim=0))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
+            capacity = max_num.item()
+
+        rank1 = moe_cumsum(mask1)    # rank1: [s, e]
+        rank2 = moe_cumsum(mask2)
+        rank2 += torch.sum(mask1, dim=-2, keepdim=True)
+
+        mask1 *= torch.lt(rank1, capacity)
+        mask2 *= torch.lt(rank2, capacity)
+
+        rank1 = torch.sum(mask1 * rank1, dim=-1)
+        rank2 = torch.sum(mask2 * rank2, dim=-1)
+
+        if use_kernel:
+            mask1 = torch.sum(mask1, dim=-1)
+            mask2 = torch.sum(mask2, dim=-1)
+
+            mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
+            dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
+
+            return logits, mask, dest_idx, num_experts * capacity
+        else:
+            weight1 = mask1 * logits.type_as(inputs)
+            weight2 = mask2 * logits.type_as(inputs)
+            rank1_sc = F.one_hot(rank1, num_classes=capacity)
+            rank2_sc = F.one_hot(rank2, num_classes=capacity)
+
+            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
+            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
+            cb_weight = cb_weight1 + cb_weight2
+            sec_mask = cb_weight.bool()
+
+            return cb_weight, sec_mask
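With use_kernel=True the routers instead return the raw pieces that the fused dispatch/combine kernels consume. A hedged sketch of the kernel-path return values for Top2Router, under the same assumptions as the previous snippet:

```python
import torch
from colossalai.utils import get_current_device
from colossalai.nn.layer.moe import Top2Router

router = Top2Router()
router.train()
gate_logits = torch.randn(16, 4, device=get_current_device())

logits, mask, dest_idx, ec = router(gate_logits, use_kernel=True)
# logits:   [tokens, experts] softmax probabilities
# mask:     [2, tokens] int32 flags, whether each of the two choices survived capacity
# dest_idx: [2, tokens] int32 flattened (expert * capacity + slot) destinations
# ec:       num_experts * capacity, the size of the dispatch buffer
router.pop_routing_loss()  # the stashed auxiliary loss still has to be consumed
```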
@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
         moe_layer = MoeLayer(DIM, num_experts, router, exp)
         layer_list.append(moe_layer)

-    model = nn.Sequential(*layer_list)
+    model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     sync_moe_model_param(model)

@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
     grad = torch.randn_like(data)

     MOE_CONTEXT.reset_loss()
-    outputs = model(data)
-    outputs.backward(grad)
+    for layer in layer_list:
+        data, _ = layer(data)
+    data.backward(grad)
     grad_handler.handle_gradient()

     assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
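Because every MoeLayer now returns an (output, l_aux) tuple, the test swaps nn.Sequential for an nn.ModuleList and chains the layers by hand. A hypothetical adapter, not part of this commit, if sequential chaining were still wanted:

```python
import torch.nn as nn
from colossalai.context import MOE_CONTEXT


class MoeOutputOnly(nn.Module):
    """Hypothetical wrapper: drop the aux-loss element so the layer fits nn.Sequential."""

    def __init__(self, moe_layer: nn.Module):
        super().__init__()
        self.moe_layer = moe_layer

    def forward(self, x):
        out, l_aux = self.moe_layer(x)
        MOE_CONTEXT.add_loss(l_aux)  # keep the routing loss from being silently dropped
        return out
```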
@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f

     # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
     layer.use_kernel = False
-    old_out = layer(tokens)
+    old_out, _ = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
     old_out.backward(grad)    # get gradient

@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.gate_weight.grad.zero_()

     layer.use_kernel = True
-    new_out = layer(tokens)    # get outputs through colossal kernel
+    new_out, _ = layer(tokens)    # get outputs through colossal kernel

     if data_type == torch.float32:
         check_equal(old_out, new_out)
@@ -19,20 +19,39 @@ from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG


-class MoeModel(CheckpointModule):
+class MoeModel(nn.Module):

     def __init__(self, checkpoint: bool = False):
+
+        class TestSubModule(CheckpointModule):
+
+            def __init__(self):
+                super().__init__(checkpoint)
-        self.proj1 = nn.Linear(4, 16)
+                expert_cls = nn.Linear
+                expert_args_dict = dict(in_features=16, out_features=16)
-        self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
-        self.proj2 = nn.Linear(16, 4)
+                self.moe = MoeModule(dim_model=16,
+                                     num_experts=8,
+                                     use_residual=True,
+                                     expert_cls=expert_cls,
+                                     **expert_args_dict)
+                self.proj = nn.Linear(16, 4)
+
+            def _forward(self, x):
+                x, y = self.moe(x)
+                x = self.proj(x)
+                return x, y
+
+        super().__init__()
+        self.test_embed = nn.Linear(4, 16)
+        self.test_transform = TestSubModule()

     def forward(self, x):
-        x = self.proj1(x)
-        x = self.moe(x)
-        x = self.proj2(x)
+        MOE_CONTEXT.reset_loss()
+
+        x = self.test_embed(x)
+        x, y = self.test_transform(x)
+
+        MOE_CONTEXT.add_loss(y)
         return x
@@ -4,6 +4,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp

+from colossalai.nn import MoeLoss
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext

@@ -26,7 +28,8 @@ def run_model_test(enable_autocast, shard_strategy_class):
     shard_strategy = shard_strategy_class()

     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, _, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
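The ZeRO tests switch from the component-provided criterion to MoeLoss, which, as used here, wraps a standard loss function and folds in the routing loss accumulated in MOE_CONTEXT scaled by aux_weight. A hedged sketch of the intended call pattern (the exact reduction of the auxiliary term is internal to MoeLoss):

```python
import torch
from colossalai.nn import MoeLoss

criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
loss = criterion(outputs, labels)  # roughly CrossEntropy(outputs, labels) + 0.01 * routing loss
loss.backward()
```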
@@ -59,7 +62,6 @@
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MOE_CONTEXT.setup(seed=42)
-    MOE_CONTEXT.reset_loss()
     run_model_test()

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
+from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port

@@ -60,7 +61,8 @@ _run_test_sharded_optim_v2(cpu_offload,
         return
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,