pull/5190/head
Xuanlei Zhao 2023-12-15 16:32:32 +08:00
parent 8aef2dba02
commit f66469e209
11 changed files with 304 additions and 186 deletions


@ -4,6 +4,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor
class MixtralSparseMLP:
@ -33,57 +34,81 @@ class MixtralSparseMLP:
Raises:
AssertionError: If the provided module is not an instance of MixtralSparseMoeBlock.
"""
-with torch.no_grad():
-    LazyInitContext.materialize(module)
-    # get the attributes of the module
-    moe_kwargs = dict(
-        num_experts=module.num_experts,
-        hidden_size=module.hidden_dim,
-        intermediate_size=module.ffn_dim,
-        router_top_k=module.top_k,
-        # router_capacity_factor_train = module.
-        # router_capacity_factor_eval = module.
-        # router_min_capacity = module.
-        # router_noisy_policy = module.
-        # router_drop_tks = module.
-        mlp_activation="silu",
-        mlp_gated=True,
-        # enable_load_balance = module.
-        # load_balance_tolerance = module.
-        # load_balance_beam_width = module.
-        # load_balance_group_swap_factor = module.
-        # enable_kernel = module.
-        # enable_comm_overlap = module.
-        # enable_hierarchical_comm = module.
-        return_gate_logits=True,
-    )
-    dtype = module.gate.weight.dtype
-    device = module.gate.weight.device
-    sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
-    w1 = None
-    w2 = None
-    w3 = None
-    for i in module.experts:
-        wi_1 = i.w1.weight.data.transpose(0, 1).unsqueeze(0)
-        wi_2 = i.w2.weight.data.transpose(0, 1).unsqueeze(0)
-        wi_3 = i.w3.weight.data.transpose(0, 1).unsqueeze(0)
-        if w1 is None:
-            w1 = wi_1
-        else:
-            w1 = torch.cat([w1, wi_1], dim=0)
-        if w2 is None:
-            w2 = wi_2
-        else:
-            w2 = torch.cat([w2, wi_2], dim=0)
-        if w3 is None:
-            w3 = wi_3
-        else:
-            w3 = torch.cat([w3, wi_3], dim=0)
-    sparse_mlp.experts.wi_gate.data = w1[:2]
-    sparse_mlp.experts.wi_up.data = w3[:2]
-    sparse_mlp.experts.wo.data = w2[:2]
-    sparse_mlp.gate_weight = module.gate.weight
-return sparse_mlp.to(dtype).to(device)
+LazyInitContext.materialize(module)
+# get the attributes of the module
+moe_kwargs = dict(
+    num_experts=module.num_experts,
+    hidden_size=module.hidden_dim,
+    intermediate_size=module.ffn_dim,
+    router_top_k=module.top_k,
+    router_norm=True,
+    router_loss=False,
+    # router_capacity_factor_train = .
+    # router_capacity_factor_eval = .
+    mlp_activation="silu",
+    mlp_gated=True,
+    # enable_load_balance = .
+    # load_balance_tolerance = .
+    # load_balance_beam_width = .
+    # load_balance_group_swap_factor = .
+    # enable_kernel = .
+    # enable_comm_overlap = .
+    # enable_hierarchical_comm = .
+    return_gate_logits=True,
+)
+dtype = module.gate.weight.dtype
+device = module.gate.weight.device
+sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
+
+# concatenate all experts into stacked weight tensors
+w1 = None
+w2 = None
+w3 = None
+for i in module.experts:
+    # original HF expert weights, transposed and given a leading expert dimension
+    wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0)
+    wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0)
+    wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0)
+    # concatenate along the expert dimension
+    w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0)
+    w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0)
+    w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0)
+
+# keep only the local experts owned by this expert-parallel rank
+if is_moe_tensor(sparse_mlp.experts.wi_gate):
+    ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate)
+    expert_num = sparse_mlp.experts.wi_gate.shape[0]
+    expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num)
+else:
+    expert_slice = slice(None)
+w1 = w1[expert_slice].clone().detach()
+w2 = w2[expert_slice].clone().detach()
+w3 = w3[expert_slice].clone().detach()
+assert (
+    w1.shape == sparse_mlp.experts.wi_gate.shape
+), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}"
+assert (
+    w2.shape == sparse_mlp.experts.wo.shape
+), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}"
+assert (
+    w3.shape == sparse_mlp.experts.wi_up.shape
+), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}"
+
+# assign the new parameters to the colossal moe module
+sparse_mlp.experts.wi_gate.data = w1
+sparse_mlp.experts.wi_up.data = w3
+sparse_mlp.experts.wo.data = w2
+sparse_mlp.gate_weight = module.gate.weight
+
+# TODO: fix
+# The old weights are still referenced somewhere, so they cannot simply be deleted.
+# Re-point their data to tensors that are kept anyway so the original storage can
+# be released; the pointer itself is never used again, so any tensor would do.
+for i in module.experts:
+    i.w1.weight.data = w1
+    i.w2.weight.data = w2
+    i.w3.weight.data = w3
+
+return sparse_mlp
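The TODO above leans on a reference-counting detail: the original expert weights are still referenced elsewhere, so deleting them would not free their memory, but re-pointing their .data at a tensor that is kept anyway does. A minimal sketch of that trick, purely illustrative and not part of this diff:

import torch

p = torch.nn.Parameter(torch.empty(4096, 4096))  # large weight
alias = p                                        # a lingering reference, as with the wrapped module
p.data = torch.empty(1)                          # the 4096x4096 storage is no longer referenced and can be freed
assert alias.data.numel() == 1                   # the alias now sees only the tiny placeholder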


@ -0,0 +1,43 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
    name="colossal_moe",
    version=fetch_version(),
    packages=find_packages(
        exclude=(
            "tests",
            "benchmarks",
            "*.egg-info",
        )
    ),
    description="Colossal-AI MoE",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)


@ -0,0 +1,31 @@
import copy

import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


class Config:
    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_act = hidden_act


def test_moe_layer():
    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
    mistral_moe = MixtralSparseMoeBlock(config).cuda()
    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()

    data = torch.randn(2, 8, 4).cuda()
    mistral_output = mistral_moe(data)[0]
    colossal_output = colossal_moe(data)[0]
    assert torch.allclose(
        mistral_output, colossal_output
    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"


if __name__ == "__main__":
    test_moe_layer()


@ -1,21 +1,20 @@
import argparse
from typing import Dict
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, T5Tokenizer
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@ -23,21 +22,6 @@ def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
data = tokenizer(
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length,
add_special_tokens=False,
)
data = {k: v.cuda() for k, v in data.items()}
data["labels"] = data["input_ids"].clone()
return data
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
self.num_samples = num_samples
@ -188,7 +172,6 @@ def main():
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
test_mode = args.model_name == "test"
# Set plugin
booster_kwargs = {}
@ -247,15 +230,20 @@ def main():
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build OpenMoe model
config = MixtralConfig(
hidden_size=32,
intermediate_size=64,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
use_cache=False,
)
model = MixtralForCausalLM(config).bfloat16()
# config = MixtralConfig(
# hidden_size=2048,
# intermediate_size=4096,
# num_hidden_layers=4,
# num_attention_heads=4,
# num_key_value_heads=4,
# use_cache=False,
# )
config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
config.use_cache = False
# config.num_local_experts = 1
init_ctx = LazyInitContext(default_device=get_current_device())
with init_ctx:
model = MixtralForCausalLM(config).bfloat16()
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
@ -270,7 +258,7 @@ def main():
)
# Set optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
@ -292,7 +280,6 @@ def main():
) as pbar:
for step in pbar:
if use_pipeline:
exit()
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
@ -307,11 +294,9 @@ def main():
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
print("1111111\n\n")
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
print(data)
outputs = model(**data)
loss = outputs["loss"]
# Backward
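The hunk stops right after the forward pass, so this script's backward code is not shown here. For orientation only, a generic ColossalAI Booster step (not the PR's code) using the booster and optimizer objects created earlier in main() would look like:

booster.backward(loss, optimizer)  # plugin-aware backward (mixed precision, ZeRO, etc.)
optimizer.step()
optimizer.zero_grad()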


@ -0,0 +1 @@
1.0.0


@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
]


@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
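The change above widens wi_gate to intermediate_size * 2 only for the fused "swiglu" activation; with a plain gated "silu" MLP, as the Mixtral conversion uses, wi_gate and wi_up remain separate [num_experts, hidden_size, intermediate_size] tensors and wo is [num_experts, intermediate_size, hidden_size]. A hedged sketch of the per-expert computation these shapes imply (illustrative, not the actual MLPExperts.forward):

import torch
import torch.nn.functional as F

def gated_experts_forward(x, wi_gate, wi_up, wo):
    # x: [num_experts, capacity, hidden_size]
    gate = F.silu(torch.bmm(x, wi_gate))  # per-expert gate projection + SiLU
    up = torch.bmm(x, wi_up)              # per-expert up projection
    return torch.bmm(gate * up, wo)       # gated product, then down projection back to hidden_size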


@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
gate_logits = F.linear(tokens, self.gate_weight)
gate_output = gate_logits.to(torch.float)
# update expert load
if self.enable_load_balance == True:
@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
used_capacity, *route_result_list = self.router(
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
return ans
if self.return_gate_logits:
return ans, gate_logits
else:
return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out
def _ep_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output
else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None
# all2all next input
@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output
def _tp_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
without overlap:
@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
assert (
dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)
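Since the Mixtral conversion constructs SparseMLP with return_gate_logits=True, the forward above now returns a tuple rather than a bare tensor, mirroring MixtralSparseMoeBlock's (hidden_states, router_logits) output. A hedged usage sketch:

out, gate_logits = sparse_mlp(hidden_states)  # gate_logits stay in the gate weight's original dtype
# The raw logits can then be used by the caller, e.g. for an external load-balancing
# loss, which is presumably why the conversion also sets router_loss=False.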


@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False):
def __init__(
self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False,
):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs must be 3D tensor and expert_indices must be 4D tensor"
assert (
    router_probs.dim() == expert_indices.dim() == 3
), "router_probs and expert_indices must both be 3D tensors"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
@ -122,25 +125,28 @@ class Top1Router(MoeRouter):
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())
low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
@ -200,7 +206,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask
return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
@ -216,20 +222,31 @@ class Top2Router(MoeRouter):
drop_tks (bool, optional): Whether drops tokens in evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@ -246,6 +263,10 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
@ -255,21 +276,22 @@ class Top2Router(MoeRouter):
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
cmask = mask1 + mask2 # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
@ -336,15 +358,18 @@ class TopKRouter(MoeRouter):
oversubscribed / reach capacity.
"""
def __init__(self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
drop_tks)
def __init__(
self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
)
def forward(
self,
@ -410,7 +435,7 @@ class TopKRouter(MoeRouter):
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
return combine_array, dispatch_mask
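The new use_norm branch in Top2Router above reproduces Mixtral's routing convention: softmax over all experts, then rescale the two selected probabilities so they sum to 1. A small numeric check of that formula (illustrative only):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])             # one token, four experts
probs = F.softmax(logits, dim=-1)                           # ~[0.61, 0.22, 0.14, 0.03]
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
print(torch.topk(probs, 2, dim=-1).values.sum())            # ~1.0: the two picked experts now share unit weight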


@ -13,7 +13,6 @@ from colossalai.utils import get_current_device
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()
@ -84,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")
@ -142,7 +143,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
epsize_param_dict = dict()
for param in model.parameters():
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
if ep_size not in epsize_param_dict:
@ -193,18 +194,13 @@ def create_ep_hierarchical_group(
assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
nproc_per_node = int(nproc_per_node)
else:
assert dist.get_world_size() % nproc_per_node == 0, \
"nproc_per_node should be a divisor of world_size."
assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
num_node = dist.get_world_size() // nproc_per_node
intra_src_rank = None
ep_intra_node_group = None
for i in range(num_node):
ep_intra_ranks = [
i * nproc_per_node + j
for j in range(nproc_per_node)
if j in ep_group_ranks
]
ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
group = dist.new_group(ep_intra_ranks)
if rank in ep_intra_ranks:
assert ep_intra_node_group is None
@ -212,10 +208,7 @@ def create_ep_hierarchical_group(
intra_src_rank = ep_intra_ranks[0]
ep_inter_node_group = None
ep_inter_ranks = [
ep_group_ranks[0] + i * nproc_per_node
for i in range(num_node)
]
ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
if len(ep_inter_ranks) > 1:
group = dist.new_group(ep_inter_ranks)
if rank in ep_inter_ranks:
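To make the hierarchical grouping above concrete: with two nodes of four GPUs each (nproc_per_node=4, world size 8) and an EP group spanning every rank, the code builds one intra-node group per node plus an inter-node group of the node-local leader ranks. The rank arithmetic alone, mirrored outside of torch.distributed for illustration:

nproc_per_node = 4
num_node = 2
ep_group_ranks = list(range(8))

intra_groups = [
    [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    for i in range(num_node)
]
inter_group = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
print(intra_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(inter_group)   # [0, 4]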