diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
index 14b3a7ee4..2a51344c3 100644
--- a/colossalai/nn/layer/moe/__init__.py
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -1,8 +1,9 @@
 from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
+from .layers import MoeLayer, MoeModule
+from .routers import MoeRouter, Top1Router, Top2Router
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter'
 ]
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index d308c1253..259f53f1a 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -1,231 +1,17 @@
-import functools
 import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.distributed as dist
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts, Experts
-from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator, autocast_softmax
+from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \
+    ReduceScatter, MoeDispatch, MoeCombine
+from colossalai.nn.layer.moe.experts import MoeExperts, Experts
+from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
 from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
-from typing import Callable, Optional, Type
-from torch.distributed import ProcessGroup
-
-
-class Top1Router(nn.Module):
-    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
-    for routing usage. More deailted function can be found in the paper about Switch Transformer
-    of Google.
-
-    Args:
-        capacity_factor_train (float, optional): Capacity factor in routing of training.
-        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
-        min_capacity (int, optional): The minimum number of the capacity of each expert.
-        select_policy (str, optional): The policy about tokens selection.
-        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
-        drop_tks (bool, optional): Whether drops tokens in evaluation
-    """
-
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Callable = None,
-                 drop_tks: bool = True):
-        super().__init__()
-        self.capacity_factor_train = capacity_factor_train
-        self.capacity_factor_eval = capacity_factor_eval
-        self.min_capacity = min_capacity
-        self.select_policy = select_policy
-        self.noisy_func = noisy_func
-        self.drop_tks = drop_tks
-
-        assert select_policy in {"first", "random"}
-        if select_policy == "random":
-            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
-                                                               high=torch.tensor(1.0,
-                                                                                 device=get_current_device())).rsample
-
-    def get_capacity(
-        self,
-        logits_shape,
-    ):
-        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
-        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
-        capacity += capacity % 2
-        capacity = max(capacity, self.min_capacity)
-        assert capacity > 0
-        return capacity
-
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
-
-        if self.noisy_func is not None and self.training:
-            inputs = self.noisy_func(inputs)
-
-        logits = autocast_softmax(inputs, dim=-1)
-        num_experts = logits.size(-1)
-        capacity = self.get_capacity(logits.shape)
-
-        top1_idx = torch.argmax(inputs, dim=-1)
-        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
-
-        if self.training:
-            me = torch.mean(logits, dim=0)
-            ce = torch.mean(mask.float(), dim=0)
-            l_aux = num_experts * torch.sum(me * ce)
-            MOE_CONTEXT.add_loss(l_aux)
-        elif not self.drop_tks:
-            max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
-            capacity = max_num.item()
-        else:
-            pass
-
-        if self.select_policy == "random":
-            rand_mask = mask * self.uniform(mask.shape)
-            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
-            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
-            ranks = moe_cumsum(mask)
-        elif self.select_policy == "first":
-            ranks = moe_cumsum(mask)
-            mask = mask * torch.lt(ranks, capacity)
-        else:
-            raise NotImplementedError("Not support such select policy yet.")
-
-        ranks = torch.sum(mask * ranks, dim=-1)
-
-        if use_kernel:
-            mask = torch.sum(mask, dim=-1)
-            mask = torch.stack([mask], dim=0).to(torch.int32)
-            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
-            return logits, mask, dest_idx, num_experts * capacity
-        else:
-            ranks = F.one_hot(ranks, num_classes=capacity)
-            weight = mask * logits.type_as(inputs)
-            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
-            sec_mask = combine_weights.bool()
-            return combine_weights, sec_mask
-
-
-class Top2Router(nn.Module):
-    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
-    for routing usage. More deailted function can be found in the paper about ViT-MoE.
-
-    Args:
-        capacity_factor_train (float, optional): Capacity factor in routing of training.
-        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
-        min_capacity (int, optional): The minimum number of the capacity of each expert
-        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
-        drop_tks (bool, optional): Whether drops tokens in evaluation.
- """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__() - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - - def get_capacity( - self, - logits_shape, - ): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - logits = autocast_softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = (mask1 + mask2) # loss: [s, e] - if self.training: - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - MOE_CONTEXT.add_loss(l_aux) - elif not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - else: - pass - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask - - -class FP32LinearGate(nn.Module): - """Gate module used in MOE layer. Just a linear function without bias. - But it should be kept as fp32 forever. 
-
-    Args:
-        d_model (int): Hidden dimension of training model
-        num_experts (int): The number experts
-
-    Attributes:
-        weight (ForceFP32Parameter): The weight of linear gate
-    """
-
-    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
-        super().__init__()
-        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
-        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
-
-    def forward(self, x: torch.Tensor):
-        return F.linear(x, self.weight)
+from typing import Optional, Type, Tuple
 
 
 @no_shard_zero_decrator(is_replicated=True)
@@ -238,17 +24,17 @@ class MoeLayer(nn.Module):
     Args:
         dim_model (int): Dimension of model.
         num_experts (int): The number of experts.
-        router (:class:`torch.nn.Module`): Instance of router used in routing.
-        experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
+        router (MoeRouter): Instance of router used in routing.
+        experts (MoeExperts): Instance of experts generated by Expert.
     """
 
-    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
+    def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
         self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
-        self.router = router
-        self.experts = experts
+        self.router: MoeRouter = router
+        self.experts: MoeExperts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
         self.ep_group = experts.dist_info.ep_group
         self.ep_size = experts.dist_info.ep_size
@@ -271,7 +57,7 @@ class MoeLayer(nn.Module):
             expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out
 
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> Tuple:
         # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
 
@@ -309,7 +95,8 @@ class MoeLayer(nn.Module):
             ans = torch.matmul(combine_weights, expert_output)
 
         ans = ans.reshape(inputs.shape)
-        return ans
+        l_aux = self.router.pop_routing_loss()
+        return ans, l_aux
 
 
 class MoeModule(nn.Module):
@@ -403,7 +190,7 @@ class MoeModule(nn.Module):
                                   experts=self.experts)
 
     def forward(self, inputs: torch.Tensor):
-        moe_output = self.moe_layer(inputs)
+        moe_output, l_aux = self.moe_layer(inputs)
 
         if self.use_residual:
             residual_output = self.residual_module(inputs)
@@ -413,4 +200,4 @@ class MoeModule(nn.Module):
         else:
             output = moe_output
 
-        return output
+        return output, l_aux
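Reviewer note: with this change, MoeLayer.forward and MoeModule.forward return a (output, l_aux) tuple instead of a bare tensor, so every call site must unpack the router's auxiliary loss. A minimal sketch of the new calling convention, assuming a launched ColossalAI distributed context with MOE_CONTEXT set up (the shapes and the 0.01 aux weight are illustrative, not part of this patch):

    import torch
    import torch.nn as nn
    from colossalai.nn.layer.moe import MoeModule

    # Built as in the updated test_moe_zero_init.py below; requires a launched
    # colossalai distributed context with MOE_CONTEXT.setup() beforehand.
    moe = MoeModule(dim_model=16, num_experts=8, use_residual=True,
                    expert_cls=nn.Linear, in_features=16, out_features=16)

    x = torch.randn(4, 16)               # illustrative batch of token embeddings
    out, l_aux = moe(x)                  # forward now returns (output, aux loss)
    loss = out.sum() + 0.01 * l_aux      # fold the routing loss into the objective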
diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py
new file mode 100644
index 000000000..c522c655a
--- /dev/null
+++ b/colossalai/nn/layer/moe/routers.py
@@ -0,0 +1,226 @@
+import math
+from abc import ABC
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from colossalai.utils import get_current_device
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer.moe._operation import moe_cumsum
+from typing import Callable, Optional
+from torch.distributed import ProcessGroup
+
+
+class MoeRouter(nn.Module, ABC):
+    """Base class for all MoE routers.
+    Args:
+        k_value (int): The value of top_k.
+        capacity_factor_train (float): Capacity factor in routing of training.
+        capacity_factor_eval (float): Capacity factor in routing of evaluation.
+        min_capacity (int): The minimum capacity of each expert.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation.
+    """
+
+    def __init__(self,
+                 k_value: int,
+                 capacity_factor_train: float,
+                 capacity_factor_eval: float,
+                 min_capacity: int,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__()
+        self.k_value = k_value
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
+        self.min_capacity = min_capacity
+        self.noisy_func = noisy_func
+        self.drop_tks = drop_tks
+        self._routing_loss = None
+
+    def get_capacity(self, logits_shape):
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity += capacity % 2
+        capacity = max(capacity, self.min_capacity)
+        assert capacity > 0
+        return capacity
+
+    def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
+        assert self._routing_loss is None
+        self._routing_loss = aux_loss
+
+    def pop_routing_loss(self) -> torch.Tensor:
+        assert self._routing_loss is not None
+        reservation = self._routing_loss
+        self._routing_loss = None
+        return reservation
+
+
+class Top1Router(MoeRouter):
+    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. More details can be found in the paper about Switch Transformer
+    of Google.
+    Args:
+        capacity_factor_train (float, optional): Capacity factor in routing of training.
+        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+        min_capacity (int, optional): The minimum capacity of each expert.
+        select_policy (str, optional): The policy about tokens selection.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation.
+    """
+
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 select_policy: str = "first",
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__(k_value=1,
+                         capacity_factor_train=capacity_factor_train,
+                         capacity_factor_eval=capacity_factor_eval,
+                         min_capacity=min_capacity,
+                         noisy_func=noisy_func,
+                         drop_tks=drop_tks)
+        self.select_policy = select_policy
+        assert select_policy in {"first", "random"}
+        if select_policy == "random":
+            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
+                                                               high=torch.tensor(1.0,
+                                                                                 device=get_current_device())).rsample
+
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
+
+        if self.noisy_func is not None and self.training:
+            inputs = self.noisy_func(inputs)
+
+        assert inputs.dtype == torch.float
+        logits = F.softmax(inputs, dim=-1)
+        num_experts = logits.size(-1)
+        capacity = self.get_capacity(logits.shape)
+
+        top1_idx = torch.argmax(inputs, dim=-1)
+        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+
+        # calculate the auxiliary loss
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(mask.float(), dim=0)
+        l_aux = num_experts * torch.sum(me * ce)
+        self.set_routing_loss(l_aux)
+
+        if not self.training and not self.drop_tks:
+            max_num = torch.max(torch.sum(mask, dim=0))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
+            capacity = max_num.item()
+
+        if self.select_policy == "random":
+            rand_mask = mask * self.uniform(mask.shape)
+            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
+            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
+            ranks = moe_cumsum(mask)
+        elif self.select_policy == "first":
+            ranks = moe_cumsum(mask)
+            mask = mask * torch.lt(ranks, capacity)
+        else:
+            raise NotImplementedError("This select policy is not supported yet.")
+
+        ranks = torch.sum(mask * ranks, dim=-1)
+
+        if use_kernel:
+            mask = torch.sum(mask, dim=-1)
+            mask = torch.stack([mask], dim=0).to(torch.int32)
+            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
+            return logits, mask, dest_idx, num_experts * capacity
+        else:
+            ranks = F.one_hot(ranks, num_classes=capacity)
+            weight = mask * logits.type_as(inputs)
+            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
+            sec_mask = combine_weights.bool()
+            return combine_weights, sec_mask
+
+
+class Top2Router(MoeRouter):
+    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. More details can be found in the paper about ViT-MoE.
+    Args:
+        capacity_factor_train (float, optional): Capacity factor in routing of training.
+        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+        min_capacity (int, optional): The minimum capacity of each expert.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation.
+ """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = (mask1 + mask2) # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index b2770f64d..e7b9a5527 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -32,7 +32,7 @@ def run_test(rank, world_size, port): moe_layer = MoeLayer(DIM, num_experts, router, exp) layer_list.append(moe_layer) - model = nn.Sequential(*layer_list) + model = nn.ModuleList(layer_list) model = model.to(get_current_device()) sync_moe_model_param(model) @@ -49,8 +49,9 @@ def run_test(rank, world_size, port): grad = torch.randn_like(data) MOE_CONTEXT.reset_loss() - outputs = model(data) - outputs.backward(grad) + for layer in layer_list: + data, _ = layer(data) + data.backward(grad) grad_handler.handle_gradient() assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index bd87a3f58..62f924164 100644 --- a/tests/test_moe/test_kernel.py +++ 
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index b2770f64d..e7b9a5527 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
         moe_layer = MoeLayer(DIM, num_experts, router, exp)
         layer_list.append(moe_layer)
 
-    model = nn.Sequential(*layer_list)
+    model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     sync_moe_model_param(model)
 
@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
     grad = torch.randn_like(data)
 
     MOE_CONTEXT.reset_loss()
-    outputs = model(data)
-    outputs.backward(grad)
+    for layer in layer_list:
+        data, _ = layer(data)
+    data.backward(grad)
     grad_handler.handle_gradient()
 
     assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index bd87a3f58..62f924164 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
    layer.use_kernel = False
-    old_out = layer(tokens)
+    old_out, _ = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
     old_out.backward(grad)    # get gradient
@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.gate_weight.grad.zero_()
 
     layer.use_kernel = True
-    new_out = layer(tokens)    # get ouputs through colossal kernel
+    new_out, _ = layer(tokens)    # get outputs through colossal kernel
 
     if data_type == torch.float32:
         check_equal(old_out, new_out)
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index b5746f562..04dc9c514 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -19,20 +19,39 @@ from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG
 
 
-class MoeModel(CheckpointModule):
+class MoeModel(nn.Module):
 
     def __init__(self, checkpoint: bool = False):
-        super().__init__(checkpoint)
-        self.proj1 = nn.Linear(4, 16)
-        expert_cls = nn.Linear
-        expert_args_dict = dict(in_features=16, out_features=16)
-        self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
-        self.proj2 = nn.Linear(16, 4)
+
+        class TestSubModule(CheckpointModule):
+
+            def __init__(self):
+                super().__init__(checkpoint)
+                expert_cls = nn.Linear
+                expert_args_dict = dict(in_features=16, out_features=16)
+                self.moe = MoeModule(dim_model=16,
+                                     num_experts=8,
+                                     use_residual=True,
+                                     expert_cls=expert_cls,
+                                     **expert_args_dict)
+                self.proj = nn.Linear(16, 4)
+
+            def _forward(self, x):
+                x, y = self.moe(x)
+                x = self.proj(x)
+                return x, y
+
+        super().__init__()
+        self.test_embed = nn.Linear(4, 16)
+        self.test_transform = TestSubModule()
 
     def forward(self, x):
-        x = self.proj1(x)
-        x = self.moe(x)
-        x = self.proj2(x)
+        MOE_CONTEXT.reset_loss()
+
+        x = self.test_embed(x)
+        x, y = self.test_transform(x)
+
+        MOE_CONTEXT.add_loss(y)
 
         return x
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index 778bf6d26..37e8a4bab 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -4,6 +4,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+
+from colossalai.nn import MoeLoss
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
@@ -26,7 +28,8 @@ def run_model_test(enable_autocast, shard_strategy_class):
     shard_strategy = shard_strategy_class()
 
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, _, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
 
     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
@@ -59,7 +62,6 @@ def run_model_test(enable_autocast, shard_strategy_class):
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MOE_CONTEXT.setup(seed=42)
-    MOE_CONTEXT.reset_loss()
     run_model_test()
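Reviewer note: the reworked MoeModel above also illustrates the new routing-loss handshake: the router stores exactly one loss per forward (set_routing_loss asserts the slot is empty) and MoeLayer pops it (pop_routing_loss asserts it is filled), so a loss can neither be dropped nor double-counted. A minimal sketch of that contract, with an illustrative loss value:

    import torch
    from colossalai.nn.layer.moe import Top2Router

    router = Top2Router()
    router.set_routing_loss(torch.tensor(0.5))   # router.forward does this once
    l_aux = router.pop_routing_loss()            # MoeLayer.forward does this once
    # popping again, or setting twice without a pop, trips an assert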
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index afc6ba5f7..da67b7610 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
+from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
         return
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
-    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    _, train_dataloader, _, optimizer_class, _ = get_components_func()
+    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
 
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
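Reviewer note: the updated tests pair the new return signature with a MoeLoss criterion and the MOE_CONTEXT loss accumulator. A hedged end-to-end sketch of the intended training step, assuming MoeLoss combines the task loss with the losses registered on MOE_CONTEXT (as the tests above suggest); `model` and `train_dataloader` are illustrative stand-ins:

    import torch
    from colossalai.nn import MoeLoss
    from colossalai.context.moe_context import MOE_CONTEXT

    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

    for data, label in train_dataloader:   # illustrative dataloader
        MOE_CONTEXT.reset_loss()           # clear aux losses from the last step
        output, l_aux = model(data)        # model now returns (logits, aux loss)
        MOE_CONTEXT.add_loss(l_aux)        # register it for the criterion
        loss = criterion(output, label)    # task loss + weighted aux loss
        loss.backward()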