pull/5190/head
Xuanlei Zhao 2023-12-15 16:32:32 +08:00
parent 8aef2dba02
commit f66469e209
11 changed files with 304 additions and 186 deletions

View File

@@ -4,6 +4,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 from colossalai.lazy import LazyInitContext
 from colossalai.moe import SparseMLP
+from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor


 class MixtralSparseMLP:
@@ -33,57 +34,81 @@ class MixtralSparseMLP:
         Raises:
             AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        LazyInitContext.materialize(module)
-        # get the attributes of the module
-        moe_kwargs = dict(
-            num_experts=module.num_experts,
-            hidden_size=module.hidden_dim,
-            intermediate_size=module.ffn_dim,
-            router_top_k=module.top_k,
-            # router_capacity_factor_train = module.
-            # router_capacity_factor_eval = module.
-            # router_min_capacity = module.
-            # router_noisy_policy = module.
-            # router_drop_tks = module.
-            mlp_activation="silu",
-            mlp_gated=True,
-            # enable_load_balance = module.
-            # load_balance_tolerance = module.
-            # load_balance_beam_width = module.
-            # load_balance_group_swap_factor = module.
-            # enable_kernel = module.
-            # enable_comm_overlap = module.
-            # enable_hierarchical_comm = module.
-            return_gate_logits=True,
-        )
-        dtype = module.gate.weight.dtype
-        device = module.gate.weight.device
-        sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
-        w1 = None
-        w2 = None
-        w3 = None
-        for i in module.experts:
-            wi_1 = i.w1.weight.data.transpose(0, 1).unsqueeze(0)
-            wi_2 = i.w2.weight.data.transpose(0, 1).unsqueeze(0)
-            wi_3 = i.w3.weight.data.transpose(0, 1).unsqueeze(0)
-            if w1 is None:
-                w1 = wi_1
-            else:
-                w1 = torch.cat([w1, wi_1], dim=0)
-            if w2 is None:
-                w2 = wi_2
-            else:
-                w2 = torch.cat([w2, wi_2], dim=0)
-            if w3 is None:
-                w3 = wi_3
-            else:
-                w3 = torch.cat([w3, wi_3], dim=0)
-        sparse_mlp.experts.wi_gate.data = w1[:2]
-        sparse_mlp.experts.wi_up.data = w3[:2]
-        sparse_mlp.experts.wo.data = w2[:2]
-        sparse_mlp.gate_weight = module.gate.weight
-        return sparse_mlp.to(dtype).to(device)
+        with torch.no_grad():
+            LazyInitContext.materialize(module)
+
+            # get the attributes of the module
+            moe_kwargs = dict(
+                num_experts=module.num_experts,
+                hidden_size=module.hidden_dim,
+                intermediate_size=module.ffn_dim,
+                router_top_k=module.top_k,
+                router_norm=True,
+                router_loss=False,
+                # router_capacity_factor_train = .
+                # router_capacity_factor_eval = .
+                mlp_activation="silu",
+                mlp_gated=True,
+                # enable_load_balance = .
+                # load_balance_tolerance = .
+                # load_balance_beam_width = .
+                # load_balance_group_swap_factor = .
+                # enable_kernel = .
+                # enable_comm_overlap = .
+                # enable_hierarchical_comm = .
+                return_gate_logits=True,
+            )
+            dtype = module.gate.weight.dtype
+            device = module.gate.weight.device
+            sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
+
+            # cat all experts
+            w1 = None
+            w2 = None
+            w3 = None
+            for i in module.experts:
+                # origin
+                wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0)
+                wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0)
+                wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0)
+                # cat
+                w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0)
+                w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0)
+                w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0)
+
+            # get local experts
+            if is_moe_tensor(sparse_mlp.experts.wi_gate):
+                ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate)
+                expert_num = sparse_mlp.experts.wi_gate.shape[0]
+                expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num)
+            else:
+                expert_slice = slice(None)
+            w1 = w1[expert_slice].clone().detach()
+            w2 = w2[expert_slice].clone().detach()
+            w3 = w3[expert_slice].clone().detach()
+            assert (
+                w1.shape == sparse_mlp.experts.wi_gate.shape
+            ), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}"
+            assert (
+                w2.shape == sparse_mlp.experts.wo.shape
+            ), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}"
+            assert (
+                w3.shape == sparse_mlp.experts.wi_up.shape
+            ), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}"
+
+            # assign new param to colossal moe module
+            sparse_mlp.experts.wi_gate.data = w1
+            sparse_mlp.experts.wi_up.data = w3
+            sparse_mlp.experts.wo.data = w2
+            sparse_mlp.gate_weight = module.gate.weight
+
+            # TODO: fix
+            # the old weight is referenced somewhere so we can not del it.
+            # Change data pointer of old weight to release memory.
+            # The pointer will not be used and can be any pointer.
+            for i in module.experts:
+                i.w1.weight.data = w1
+                i.w2.weight.data = w2
+                i.w3.weight.data = w3
+        return sparse_mlp
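Note on the conversion above: Hugging Face Mixtral keeps one nn.Linear per expert, whose weight has shape [out_features, in_features], while SparseMLP stores each projection as a single stacked parameter of shape [num_experts, hidden_size, intermediate_size] (see the experts.py change later in this commit). That is why every expert's w1/w2/w3 is transposed and concatenated along a new leading dimension before being sliced to the local EP rank. A minimal standalone sketch of that transpose-and-stack step (illustrative sizes, not part of the commit):

import torch
import torch.nn.functional as F

num_experts, hidden_size, intermediate_size = 4, 8, 16

# HF-style per-expert Linear weights: [out_features, in_features]
hf_w1 = [torch.randn(intermediate_size, hidden_size) for _ in range(num_experts)]

# SparseMLP-style stacked tensor: [num_experts, hidden_size, intermediate_size]
stacked = torch.cat([w.transpose(0, 1).unsqueeze(0) for w in hf_w1], dim=0)
assert stacked.shape == (num_experts, hidden_size, intermediate_size)

# Equivalence check: x @ stacked[i] reproduces F.linear(x, hf_w1[i]) for every expert.
x = torch.randn(2, hidden_size)
for i in range(num_experts):
    assert torch.allclose(x @ stacked[i], F.linear(x, hf_w1[i]), atol=1e-6)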

View File

@@ -0,0 +1,43 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
name="colossal_moe",
version=fetch_version(),
packages=find_packages(
exclude=(
"tests",
"benchmarks",
"*.egg-info",
)
),
description="Colossal-AI MoE",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
url="https://github.com/hpcaitech",
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: System :: Distributed Computing",
],
)

View File

@@ -0,0 +1,31 @@
import copy

import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


class Config:
    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_act = hidden_act


def test_moe_layer():
    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
    mistral_moe = MixtralSparseMoeBlock(config).cuda()
    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()

    data = torch.randn(2, 8, 4).cuda()
    mistral_output = mistral_moe(data)[0]
    colossal_output = colossal_moe(data)[0]
    assert torch.allclose(
        mistral_output, colossal_output
    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"


if __name__ == "__main__":
    test_moe_layer()

View File

@@ -1,21 +1,20 @@
 import argparse
-from typing import Dict

 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer, T5Tokenizer
+from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.lazy import LazyInitContext
+from colossalai.moe import MOE_MANAGER, apply_load_balance
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -23,21 +22,6 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}


-def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
-    texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
-    data = tokenizer(
-        texts,
-        return_tensors="pt",
-        padding="max_length",
-        truncation=True,
-        max_length=max_length,
-        add_special_tokens=False,
-    )
-    data = {k: v.cuda() for k, v in data.items()}
-    data["labels"] = data["input_ids"].clone()
-    return data
-
-
 class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
         self.num_samples = num_samples
@@ -188,7 +172,6 @@ def main():
     # Launch ColossalAI
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()
-    test_mode = args.model_name == "test"

     # Set plugin
     booster_kwargs = {}
@@ -247,15 +230,20 @@ def main():
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

     # Build OpenMoe model
-    config = MixtralConfig(
-        hidden_size=32,
-        intermediate_size=64,
-        num_hidden_layers=4,
-        num_attention_heads=4,
-        num_key_value_heads=4,
-        use_cache=False,
-    )
-    model = MixtralForCausalLM(config).bfloat16()
+    # config = MixtralConfig(
+    #     hidden_size=2048,
+    #     intermediate_size=4096,
+    #     num_hidden_layers=4,
+    #     num_attention_heads=4,
+    #     num_key_value_heads=4,
+    #     use_cache=False,
+    # )
+    config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    config.use_cache = False
+    # config.num_local_experts = 1
+    init_ctx = LazyInitContext(default_device=get_current_device())
+    with init_ctx:
+        model = MixtralForCausalLM(config).bfloat16()
     coordinator.print_on_master(f"Finish init model with config:\n{config}")

     # Enable gradient checkpointing
@@ -270,7 +258,7 @@ def main():
     )

     # Set optimizer
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
@@ -292,7 +280,6 @@ def main():
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
-                    exit()
                     # Forward pass
                     outputs = booster.execute_pipeline(
                         train_dataloader_iter,
@@ -307,11 +294,9 @@ def main():
                     loss = outputs["loss"]
                     pbar.set_postfix({"loss": loss.item()})
                 else:
-                    print("1111111\n\n")
                     # Forward pass
                     data = next(train_dataloader_iter)
                     data = move_to_cuda(data, torch.cuda.current_device())
-                    print(data)
                     outputs = model(**data)
                     loss = outputs["loss"]
                     # Backward

View File

@@ -0,0 +1 @@
1.0.0

View File

@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO from .checkpoint import MoECheckpintIO
from .experts import MLPExperts from .experts import MLPExperts
from .layers import SparseMLP from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator", "UniformNoiseGenerator",
"SparseMLP", "SparseMLP",
"MoECheckpintIO", "MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
] ]

View File

@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
             self.ep_size = 1

         if gated:
-            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+            self.wi_gate = nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+                )
+            )
             self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         else:
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
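For context on the size change above: with mlp_activation="silu" and mlp_gated=True (the settings the Mixtral conversion in this commit uses, mapping w1 → wi_gate, w3 → wi_up, w2 → wo), the gate and up projections stay separate [num_experts, hidden_size, intermediate_size] tensors, so wi_gate only needs the doubled last dimension on the existing swiglu path. A rough per-expert sketch of the silu-gated computation (illustrative shapes, not ColossalAI code; routing and capacity are ignored):

import torch
import torch.nn.functional as F

E, H, I = 2, 4, 8                 # made-up sizes
x = torch.randn(3, H)             # a few tokens already dispatched to one expert
wi_gate = torch.randn(E, H, I)    # gate projection (maps to Mixtral w1)
wi_up = torch.randn(E, H, I)      # up projection (maps to Mixtral w3)
wo = torch.randn(E, I, H)         # down projection (maps to Mixtral w2)

e = 0                             # a single expert
hidden = F.silu(x @ wi_gate[e]) * (x @ wi_up[e])
out = hidden @ wo[e]
assert out.shape == (3, H)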

View File

@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
+        router_loss: bool = True,
+        router_norm: bool = False,
         router_capacity_factor_train: float = 1.25,
         router_capacity_factor_eval: float = 2.0,
         router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
         enable_kernel: bool = False,
         enable_comm_overlap: bool = False,
         enable_hierarchical_comm: bool = False,
+        return_gate_logits: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_experts = num_experts
         self.gated = mlp_gated
+        self.return_gate_logits = return_gate_logits
         self.enable_kernel = enable_kernel
         self.enable_comm_overlap = enable_comm_overlap
         self.expert_parallel = MOE_MANAGER.get_parallel()
+        self.router_loss = router_loss
+        self.router_norm = router_norm

         # moe router
         noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
         tokens = inputs.reshape(-1, self.hidden_size)

         # the data type of the inputs in the gating should be fp32
-        fp32_input = tokens.to(torch.float)
-        fp32_weight = self.gate_weight.to(torch.float)
-        gate_output = F.linear(fp32_input, fp32_weight)
+        gate_logits = F.linear(tokens, self.gate_weight)
+        gate_output = gate_logits.to(torch.float)

         # update expert load
         if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
         # the result from the router
         used_capacity, *route_result_list = self.router(
-            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+            inputs=gate_output,
+            use_kernel=self.enable_kernel,
+            ep_group=self.ep_group,
+            use_loss=self.router_loss,
+            use_norm=self.router_norm,
+        )

         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
-            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
-                                      "Please use Experts build function.")
+            raise NotImplementedError(
+                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
+            )

         if self.enable_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
         ans = torch.matmul(combine_weights, expert_output)
         ans = ans.reshape(inputs.shape)

-        return ans
+        if self.return_gate_logits:
+            return ans, gate_logits
+        else:
+            return ans

     def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_in = expert_in.unsqueeze(0)
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
         return expert_out

     def _ep_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_input = HierarchicalAllToAll.apply(
+                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_output = HierarchicalAllToAll.apply(
+                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
             chunk_size = dispatch_data.shape[1] // NUM_CHUNK
             input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
             dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
             for i in range(NUM_CHUNK + NUM_STAGES - 1):
                 if expert_out is not None:
                     expert_out.handle.wait()
-                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
+                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                     offset += chunk_size
                     expert_out = None

                 # all2all last output
                 if _expert_out is not None:
-                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+                    expert_out = Capsule(
+                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
+                    )
                     _expert_out = None

                 # all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
         return output

     def _tp_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
-                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+            assert (
+                dispatch_data.shape[0] % NUM_CHUNK == 0
+            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
             chunk_size = dispatch_data.shape[0] // NUM_CHUNK
             chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
             output = torch.empty_like(dispatch_data)
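With router_loss=False and return_gate_logits=True (the settings the Mixtral conversion passes in), the auxiliary loss is no longer accumulated inside the router, and forward() now hands the raw gate logits back to the caller. A hedged sketch of how a caller could turn those logits into a switch-transformer-style load-balancing penalty (the helper below is illustrative, not a colossalai API):

import torch
import torch.nn.functional as F


def aux_load_balance_loss(gate_logits: torch.Tensor, top_k: int = 2) -> torch.Tensor:
    # gate_logits: (num_tokens, num_experts), e.g. the second output of SparseMLP.forward()
    num_experts = gate_logits.size(-1)
    probs = F.softmax(gate_logits.float(), dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)            # (num_tokens, top_k)
    expert_mask = F.one_hot(selected, num_experts).float()    # (num_tokens, top_k, num_experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))          # fraction of routed slots per expert
    router_prob_per_expert = probs.mean(dim=0)                # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)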

View File

@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

-    def __init__(self,
-                 k_value: int,
-                 capacity_factor_train: float,
-                 capacity_factor_eval: float,
-                 min_capacity: int,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True,
-                 use_kernel: bool = False):
+    def __init__(
+        self,
+        k_value: int,
+        capacity_factor_train: float,
+        capacity_factor_eval: float,
+        min_capacity: int,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+        use_kernel: bool = False,
+    ):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train
@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
         if router_probs.dim() == expert_indices.dim() == 2:
             router_probs = router_probs.unsqueeze(0)
             expert_indices = expert_indices.unsqueeze(0)
-        assert router_probs.dim() == expert_indices.dim() == 3, \
-            "router_probs must be 3D tensor and expert_indices must be 4D tensor"
+        assert (
+            router_probs.dim() == expert_indices.dim() == 3
+        ), "router_probs must be 3D tensor and expert_indices must be 4D tensor"

         # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
         expert_mask = F.one_hot(expert_indices, num_experts)
@@ -122,25 +125,28 @@ class Top1Router(MoeRouter):
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=1,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        select_policy: str = "first",
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=1,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
             self.uniform = torch.distributions.uniform.Uniform(
-                low=torch.tensor(0.0, device=get_current_device()),
-                high=torch.tensor(1.0, device=get_current_device())
+                low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
             ).rsample

     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
@@ -200,7 +206,7 @@ class Top1Router(MoeRouter):
             weight = mask * probs.type_as(inputs)
             combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
             sec_mask = combine_weights.bool()
-        return used_capacity, combine_weights, sec_mask
+        return used_capacity, combine_weights, sec_mask, probs


 class Top2Router(MoeRouter):
@@ -216,20 +222,31 @@ class Top2Router(MoeRouter):
        drop_tks (bool, optional): Whether drops tokens in evaluation.
    """

-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=2,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=2,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )

-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_norm: bool = False,
+        use_loss: bool = True,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -246,6 +263,10 @@ class Top2Router(MoeRouter):
         assert inputs.dtype == torch.float
         probs = F.softmax(inputs, dim=-1)
+        if use_norm:
+            routing_weights, _ = torch.topk(probs, 2, dim=-1)
+            probs = probs / routing_weights.sum(dim=-1, keepdim=True)
+
         num_experts = probs.size(-1)
         capacity = self.get_capacity(inputs.shape)
@@ -255,21 +276,22 @@ class Top2Router(MoeRouter):
         top2_idx = torch.argmax(logits_except1, dim=-1)
         mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

-        cmask = (mask1 + mask2)  # loss: [s, e]
+        cmask = mask1 + mask2  # loss: [s, e]
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1

         # calculate loss
-        expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
-        self.set_aux_loss(probs, expert_indices, num_experts)
-        self.set_z_loss(inputs)
-        self.pop_router_loss()
+        if use_loss:
+            expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
+            self.set_aux_loss(probs, expert_indices, num_experts)
+            self.set_z_loss(inputs)
+            self.pop_router_loss()

         if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(cmask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()

         rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
         rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
         rank2 += torch.sum(mask1, dim=-2, keepdim=True)
@@ -336,15 +358,18 @@ class TopKRouter(MoeRouter):
        oversubscribed / reach capacity.
    """

-    def __init__(self,
-                 num_selected_experts: int,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
-                         drop_tks)
+    def __init__(
+        self,
+        num_selected_experts: int,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
+        )

     def forward(
         self,
@@ -410,7 +435,7 @@ class TopKRouter(MoeRouter):
         # The combine array will be used for combining expert outputs, scaled by the
         # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
         # expert_capacity].
-        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
+        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

         return combine_array, dispatch_mask
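The new use_norm branch in Top2Router mirrors how the HF Mixtral block treats its routing weights: after a softmax over all experts, the two selected probabilities are rescaled so that they sum to 1 before being used as combine weights. A small numeric sketch (illustrative values only, not part of the commit):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # made-up gate logits for one token
probs = F.softmax(logits, dim=-1)                # probabilities over all 4 experts
top2, _ = torch.topk(probs, 2, dim=-1)
probs = probs / top2.sum(dim=-1, keepdim=True)   # what the use_norm branch does
# The two selected experts now carry weights that sum to exactly 1.
assert torch.allclose(torch.topk(probs, 2, dim=-1).values.sum(dim=-1), torch.ones(1))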

View File

@@ -13,7 +13,6 @@ from colossalai.utils import get_current_device


 class ForceFP32Parameter(torch.nn.Parameter):
-
     def half(self, memory_format=None):
         return self.data.clone()

@@ -84,6 +83,8 @@ def get_activation(act: str) -> Callable:
         return torch.nn.GELU()
     elif act == "swiglu":
         return SwiGLU
+    elif act == "silu":
+        return torch.nn.SiLU()
     else:
         raise NotImplementedError("Unsupported activation function")
@@ -142,7 +143,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
     epsize_param_dict = dict()
     for param in model.parameters():
         if not is_moe_tensor(param):
-            ep_size = 1 # set ep_size to 1 for dp parameters
+            ep_size = 1  # set ep_size to 1 for dp parameters
         else:
             ep_size = get_ep_size(param)
         if ep_size not in epsize_param_dict:
@@ -193,18 +194,13 @@ def create_ep_hierarchical_group(
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
         nproc_per_node = int(nproc_per_node)
     else:
-        assert dist.get_world_size() % nproc_per_node == 0, \
-            "nproc_per_node should be a divisor of world_size."
+        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

     intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
-        ep_intra_ranks = [
-            i * nproc_per_node + j
-            for j in range(nproc_per_node)
-            if j in ep_group_ranks
-        ]
+        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
@@ -212,10 +208,7 @@ def create_ep_hierarchical_group(
             intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
-    ep_inter_ranks = [
-        ep_group_ranks[0] + i * nproc_per_node
-        for i in range(num_node)
-    ]
+    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
     if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks: