[moe] fix moe bugs (#1633)

pull/1635/head^2
HELSON 2022-09-23 15:33:57 +08:00 committed by GitHub
parent 702dbc5288
commit a088022efc
8 changed files with 287 additions and 249 deletions

View File

@ -1,8 +1,9 @@
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter'
]
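After this change Top1Router, Top2Router and the new MoeRouter base class are defined in colossalai.nn.layer.moe.routers but stay re-exported from the package root, so both import paths below resolve to the same classes; the router construction is a minimal sketch using the defaults introduced in this commit:

    from colossalai.nn.layer.moe import MoeRouter, Top1Router, Top2Router
    from colossalai.nn.layer.moe.routers import Top2Router as Top2RouterDirect

    # defaults match the constructor signature defined in routers.py below
    router = Top2Router(capacity_factor_train=1.25, capacity_factor_eval=2.0, min_capacity=4)
    assert Top2Router is Top2RouterDirect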

View File

@ -1,231 +1,17 @@
import functools
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts, Experts
from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator, autocast_softmax
from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \
ReduceScatter, MoeDispatch, MoeCombine
from colossalai.nn.layer.moe.experts import MoeExperts, Experts
from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from typing import Callable, Optional, Type
from torch.distributed import ProcessGroup
class Top1Router(nn.Module):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed information can be found in Google's Switch Transformer paper.
Args:
capacity_factor_train (float, optional): Capacity factor for routing during training.
capacity_factor_eval (float, optional): Capacity factor for routing during evaluation.
min_capacity (int, optional): The minimum capacity of each expert.
select_policy (str, optional): The policy for token selection.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the routing logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__()
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.select_policy = select_policy
self.noisy_func = noisy_func
self.drop_tks = drop_tks
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample
def get_capacity(
self,
logits_shape,
):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
MOE_CONTEXT.add_loss(l_aux)
elif not self.drop_tks:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
else:
pass
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask)
elif self.select_policy == "first":
ranks = moe_cumsum(mask)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("This select policy is not supported yet.")
ranks = torch.sum(mask * ranks, dim=-1)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * logits.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
class Top2Router(nn.Module):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed information can be found in the ViT-MoE paper.
Args:
capacity_factor_train (float, optional): Capacity factor for routing during training.
capacity_factor_eval (float, optional): Capacity factor for routing during evaluation.
min_capacity (int, optional): The minimum capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the routing logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__()
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
def get_capacity(
self,
logits_shape,
):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
# inputs: [s, h]
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(logits, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
MOE_CONTEXT.add_loss(l_aux)
elif not self.drop_tks:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
else:
pass
rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
weight1 = mask1 * logits.type_as(inputs)
weight2 = mask2 * logits.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
return cb_weight, sec_mask
class FP32LinearGate(nn.Module):
"""Gate module used in MOE layer. Just a linear function without bias.
But it should be kept as fp32 forever.
Args:
d_model (int): Hidden dimension of training model
num_experts (int): The number experts
Attributes:
weight (ForceFP32Parameter): The weight of linear gate
"""
def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
super().__init__()
self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
def forward(self, x: torch.Tensor):
return F.linear(x, self.weight)
from typing import Optional, Type, Tuple
@no_shard_zero_decrator(is_replicated=True)
@ -238,17 +24,17 @@ class MoeLayer(nn.Module):
Args:
dim_model (int): Dimension of model.
num_experts (int): The number of experts.
router (:class:`torch.nn.Module`): Instance of router used in routing.
experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
router (MoeRouter): Instance of router used in routing.
experts (MoeExperts): Instance of experts generated by Expert.
"""
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
self.router = router
self.experts = experts
self.router: MoeRouter = router
self.experts: MoeExperts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
self.ep_group = experts.dist_info.ep_group
self.ep_size = experts.dist_info.ep_size
@ -271,7 +57,7 @@ class MoeLayer(nn.Module):
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor) -> Tuple:
# reshape the input tokens
tokens = inputs.reshape(-1, self.d_model)
@ -309,7 +95,8 @@ class MoeLayer(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
return ans
l_aux = self.router.pop_routing_loss()
return ans, l_aux
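MoeLayer.forward now returns the routing loss alongside the output, so callers unpack a tuple instead of a single tensor. A hedged sketch of the new calling convention (the 0.01 weight is illustrative; in this commit the weighting is normally handled by MoeLoss or MOE_CONTEXT):

    out, l_aux = moe_layer(inputs)        # inputs: [..., d_model]
    loss = task_loss + 0.01 * l_aux       # fold the auxiliary loss into the objective
    loss.backward()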
class MoeModule(nn.Module):
@ -403,7 +190,7 @@ class MoeModule(nn.Module):
experts=self.experts)
def forward(self, inputs: torch.Tensor):
moe_output = self.moe_layer(inputs)
moe_output, l_aux = self.moe_layer(inputs)
if self.use_residual:
residual_output = self.residual_module(inputs)
@ -413,4 +200,4 @@ class MoeModule(nn.Module):
else:
output = moe_output
return output
return output, l_aux
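MoeModule propagates the auxiliary loss to its caller in the same way. One way to consume it, mirroring the updated MoeModel test later in this diff, is to hand it to MOE_CONTEXT so that MoeLoss can pick it up:

    from colossalai.context import MOE_CONTEXT

    def forward_step(moe_module, x):
        MOE_CONTEXT.reset_loss()          # clear the losses of the previous step
        out, l_aux = moe_module(x)
        MOE_CONTEXT.add_loss(l_aux)       # accumulated loss is later consumed by MoeLoss
        return out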

View File

@ -0,0 +1,226 @@
import math
from abc import ABC
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import moe_cumsum
from typing import Callable, Optional
from torch.distributed import ProcessGroup
class MoeRouter(nn.Module, ABC):
"""Base class for all MoE routers.
Args:
k_value (int): The value of top_k.
capacity_factor_train (float): Capacity factor for routing during training.
capacity_factor_eval (float): Capacity factor for routing during evaluation.
min_capacity (int): The minimum capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the routing logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation.
"""
def __init__(self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._routing_loss = None
def get_capacity(self, logits_shape):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
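As a concrete illustration of get_capacity: a Top2Router (k_value=2) routing 1024 tokens to 8 experts with capacity_factor_train=1.25 yields capacity = floor(2 * 1.25 * 1024 / 8) = 320; the value is already even and above min_capacity, so each expert may receive at most 320 tokens in that step.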
def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
assert self._routing_loss is None
self._routing_loss = aux_loss
def pop_routing_loss(self) -> torch.Tensor:
assert self._routing_loss is not None
reservation = self._routing_loss
self._routing_loss = None
return reservation
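The set/pop pair replaces the routers' previous direct calls to MOE_CONTEXT.add_loss: a router stashes its auxiliary loss during forward, and exactly one consumer (MoeLayer in this commit) pops it afterwards. A minimal sketch of that contract, assuming the pure-PyTorch fallback path of moe_cumsum (no fused kernel):

    import torch
    from colossalai.nn.layer.moe import Top1Router

    router = Top1Router()
    gate_logits = torch.randn(16, 4)                  # illustrative: 16 tokens, 4 experts, fp32 as asserted
    combine_weights, sec_mask = router(gate_logits)   # forward() stores l_aux via set_routing_loss()
    l_aux = router.pop_routing_loss()                 # pop before the next forward, or its assert fires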
class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed information can be found in Google's Switch Transformer paper.
Args:
capacity_factor_train (float, optional): Capacity factor for routing during training.
capacity_factor_eval (float, optional): Capacity factor for routing during evaluation.
min_capacity (int, optional): The minimum capacity of each expert.
select_policy (str, optional): The policy for token selection.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the routing logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
logits = F.softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
# calculate the auxiliary loss
me = torch.mean(logits, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
self.set_routing_loss(l_aux)
if not self.training and not self.drop_tks:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask)
elif self.select_policy == "first":
ranks = moe_cumsum(mask)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("This select policy is not supported yet.")
ranks = torch.sum(mask * ranks, dim=-1)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * logits.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
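The two return paths differ in shape: the kernel path yields flat index tensors for MoeDispatch/MoeCombine, while the dense path yields [s, e, c] tensors consumed by plain matrix multiplications. A shape-checking sketch of the dense path with illustrative sizes (again assuming the pure-PyTorch fallback):

    import torch
    from colossalai.nn.layer.moe import Top1Router

    s, e = 16, 4                                      # tokens, experts
    router = Top1Router()
    combine_weights, sec_mask = router(torch.randn(s, e))
    c = router.get_capacity((s, e))                   # capacity is derived from the logits shape
    assert combine_weights.shape == (s, e, c)
    assert sec_mask.shape == (s, e, c) and sec_mask.dtype == torch.bool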
class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed information can be found in the ViT-MoE paper.
Args:
capacity_factor_train (float, optional): Capacity factor for routing during training.
capacity_factor_eval (float, optional): Capacity factor for routing during evaluation.
min_capacity (int, optional): The minimum capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the routing logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
# inputs: [s, h]
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(logits, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
# calculate the auxiliary loss
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
self.set_routing_loss(l_aux)
if not self.training and not self.drop_tks:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
weight1 = mask1 * logits.type_as(inputs)
weight2 = mask2 * logits.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
return cb_weight, sec_mask
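For reference, one common way such [s, e, c] dispatch masks and combine weights are consumed is the GShard-style einsum pattern sketched below; this is a generic illustration, not necessarily the exact matmul/kernel path used inside MoeLayer:

    import torch

    def dispatch_and_combine(tokens, sec_mask, cb_weight, expert_fn):
        # tokens: [s, h]; sec_mask, cb_weight: [s, e, c] from a router's dense path
        dispatched = torch.einsum('sec,sh->ech', sec_mask.type_as(tokens), tokens)   # per-expert buffers
        expert_out = expert_fn(dispatched)                                            # [e, c, h] -> [e, c, h]
        return torch.einsum('sec,ech->sh', cb_weight.type_as(expert_out), expert_out)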

View File

@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
moe_layer = MoeLayer(DIM, num_experts, router, exp)
layer_list.append(moe_layer)
model = nn.Sequential(*layer_list)
model = nn.ModuleList(layer_list)
model = model.to(get_current_device())
sync_moe_model_param(model)
@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
grad = torch.randn_like(data)
MOE_CONTEXT.reset_loss()
outputs = model(data)
outputs.backward(grad)
for layer in layer_list:
data, _ = layer(data)
data.backward(grad)
grad_handler.handle_gradient()
assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)

View File

@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# use matrix multiplication instead of COL_MOE_KERNEL in MoE dispatch and combine
layer.use_kernel = False
old_out = layer(tokens)
old_out, _ = layer(tokens)
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad) # get gradient
@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
layer.gate_weight.grad.zero_()
layer.use_kernel = True
new_out = layer(tokens) # get outputs through colossal kernel
new_out, _ = layer(tokens) # get outputs through colossal kernel
if data_type == torch.float32:
check_equal(old_out, new_out)

View File

@ -19,20 +19,39 @@ from colossalai.utils import get_current_device
from tests.test_zero.common import CONFIG
class MoeModel(CheckpointModule):
class MoeModel(nn.Module):
def __init__(self, checkpoint: bool = False):
class TestSubModule(CheckpointModule):
def __init__(self):
super().__init__(checkpoint)
self.proj1 = nn.Linear(4, 16)
expert_cls = nn.Linear
expert_args_dict = dict(in_features=16, out_features=16)
self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
self.proj2 = nn.Linear(16, 4)
self.moe = MoeModule(dim_model=16,
num_experts=8,
use_residual=True,
expert_cls=expert_cls,
**expert_args_dict)
self.proj = nn.Linear(16, 4)
def _forward(self, x):
x, y = self.moe(x)
x = self.proj(x)
return x, y
super().__init__()
self.test_embed = nn.Linear(4, 16)
self.test_transform = TestSubModule()
def forward(self, x):
x = self.proj1(x)
x = self.moe(x)
x = self.proj2(x)
MOE_CONTEXT.reset_loss()
x = self.test_embed(x)
x, y = self.test_transform(x)
MOE_CONTEXT.add_loss(y)
return x

View File

@ -4,6 +4,8 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.nn import MoeLoss
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
@ -26,7 +28,8 @@ def run_model_test(enable_autocast, shard_strategy_class):
shard_strategy = shard_strategy_class()
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
_, train_dataloader, _, _, criterion = get_components_func()
_, train_dataloader, _, optimizer_class, _ = get_components_func()
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
shard_strategy=shard_strategy,
@ -59,7 +62,6 @@ def run_model_test(enable_autocast, shard_strategy_class):
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
MOE_CONTEXT.reset_loss()
run_model_test()
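The zero tests now build the MoE criterion explicitly instead of taking it from the component registry. A hedged sketch of how such a criterion is typically used in the training step, assuming MoeLoss behaves like a standard loss wrapper that also folds in the accumulated routing loss:

    import torch
    from colossalai.nn import MoeLoss

    criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
    # inside the training loop (assumed usage):
    #   MOE_CONTEXT.reset_loss()
    #   logits = model(data)
    #   loss = criterion(logits, labels)   # task loss + 0.01 * accumulated auxiliary loss
    #   loss.backward()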

View File

@ -5,6 +5,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn import MoeLoss
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
@ -60,7 +61,8 @@ def _run_test_sharded_optim_v2(cpu_offload,
return
MOE_CONTEXT.reset_loss()
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
_, train_dataloader, _, optimizer_class, criterion = get_components_func()
_, train_dataloader, _, optimizer_class, _ = get_components_func()
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,