mirror of https://github.com/hpcaitech/ColossalAI
update
parent 8aef2dba02
commit f66469e209
@@ -4,6 +4,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor


class MixtralSparseMLP:

@@ -33,57 +34,81 @@ class MixtralSparseMLP:
        Raises:
            AssertionError: If the provided module is not an instance of nn.LayerNorm.
        """
        with torch.no_grad():
            LazyInitContext.materialize(module)

        LazyInitContext.materialize(module)
        # get the attributes of the module
        moe_kwargs = dict(
            num_experts=module.num_experts,
            hidden_size=module.hidden_dim,
            intermediate_size=module.ffn_dim,
            router_top_k=module.top_k,
            # router_capacity_factor_train = module.
            # router_capacity_factor_eval = module.
            # router_min_capacity = module.
            # router_noisy_policy = module.
            # router_drop_tks = module.
            mlp_activation="silu",
            mlp_gated=True,
            # enable_load_balance = module.
            # load_balance_tolerance = module.
            # load_balance_beam_width = module.
            # load_balance_group_swap_factor = module.
            # enable_kernel = module.
            # enable_comm_overlap = module.
            # enable_hierarchical_comm = module.
            return_gate_logits=True,
        )
        dtype = module.gate.weight.dtype
        device = module.gate.weight.device
        # get the attributes of the module
        moe_kwargs = dict(
            num_experts=module.num_experts,
            hidden_size=module.hidden_dim,
            intermediate_size=module.ffn_dim,
            router_top_k=module.top_k,
            router_norm=True,
            router_loss=False,
            # router_capacity_factor_train = .
            # router_capacity_factor_eval = .
            mlp_activation="silu",
            mlp_gated=True,
            # enable_load_balance = .
            # load_balance_tolerance = .
            # load_balance_beam_width = .
            # load_balance_group_swap_factor = .
            # enable_kernel = .
            # enable_comm_overlap = .
            # enable_hierarchical_comm = .
            return_gate_logits=True,
        )
        dtype = module.gate.weight.dtype
        device = module.gate.weight.device
        sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)

        sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
        w1 = None
        w2 = None
        w3 = None
        for i in module.experts:
            wi_1 = i.w1.weight.data.transpose(0, 1).unsqueeze(0)
            wi_2 = i.w2.weight.data.transpose(0, 1).unsqueeze(0)
            wi_3 = i.w3.weight.data.transpose(0, 1).unsqueeze(0)
            if w1 is None:
                w1 = wi_1
        # cat all experts
        w1 = None
        w2 = None
        w3 = None
        for i in module.experts:
            # origin
            wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0)
            wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0)
            wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0)
            # cat
            w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0)
            w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0)
            w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0)

        # get local experts
        if is_moe_tensor(sparse_mlp.experts.wi_gate):
            ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate)
            expert_num = sparse_mlp.experts.wi_gate.shape[0]
            expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num)
        else:
                w1 = torch.cat([w1, wi_1], dim=0)
            if w2 is None:
                w2 = wi_2
            else:
                w2 = torch.cat([w2, wi_2], dim=0)
            if w3 is None:
                w3 = wi_3
            else:
                w3 = torch.cat([w3, wi_3], dim=0)
            expert_slice = slice(None)
        w1 = w1[expert_slice].clone().detach()
        w2 = w2[expert_slice].clone().detach()
        w3 = w3[expert_slice].clone().detach()
        assert (
            w1.shape == sparse_mlp.experts.wi_gate.shape
        ), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}"
        assert (
            w2.shape == sparse_mlp.experts.wo.shape
        ), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}"
        assert (
            w3.shape == sparse_mlp.experts.wi_up.shape
        ), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}"

        sparse_mlp.experts.wi_gate.data = w1[:2]
        sparse_mlp.experts.wi_up.data = w3[:2]
        sparse_mlp.experts.wo.data = w2[:2]
        sparse_mlp.gate_weight = module.gate.weight
        # assign new param to colossal moe moudle
        sparse_mlp.experts.wi_gate.data = w1
        sparse_mlp.experts.wi_up.data = w3
        sparse_mlp.experts.wo.data = w2
        sparse_mlp.gate_weight = module.gate.weight

        return sparse_mlp.to(dtype).to(device)
        # TODO: fix
        # the old weight is referenced somewhere so we can not del it.
        # Change data pointer of old weight to release memory.
        # The pointer will not be used and can be any pointer.
        for i in module.experts:
            i.w1.weight.data = w1
            i.w2.weight.data = w2
            i.w3.weight.data = w3

        return sparse_mlp
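Note: the conversion above concatenates each HF expert's (intermediate, hidden) Linear weights into stacked (num_experts, hidden, intermediate) tensors before copying them into SparseMLP, which is what the shape asserts check. A minimal shape-only sketch of that stacking (hypothetical sizes, independent of the classes above):

import torch

num_experts, hidden_dim, ffn_dim = 4, 16, 32  # hypothetical sizes

# HF Mixtral stores each expert's w1/w3 as nn.Linear(hidden_dim, ffn_dim), whose weight is (ffn_dim, hidden_dim)
hf_w1 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]

# transpose to (hidden_dim, ffn_dim), add an expert axis, then concatenate along it
stacked = torch.cat([w.transpose(0, 1).unsqueeze(0) for w in hf_w1], dim=0)
assert stacked.shape == (num_experts, hidden_dim, ffn_dim)  # same layout the wi_gate assert expects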
@@ -0,0 +1,43 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
    name="colossal_moe",
    version=fetch_version(),
    packages=find_packages(
        exclude=(
            "tests",
            "benchmarks",
            "*.egg-info",
        )
    ),
    description="Colossal-AI MoE",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)
@@ -0,0 +1,31 @@
import copy

import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


class Config:
    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_act = hidden_act


def test_moe_layer():
    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
    mistral_moe = MixtralSparseMoeBlock(config).cuda()
    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()

    data = torch.randn(2, 8, 4).cuda()
    mistral_output = mistral_moe(data)[0]
    colossal_output = colossal_moe(data)[0]
    assert torch.allclose(
        mistral_output, colossal_output
    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"


if __name__ == "__main__":
    test_moe_layer()
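Note: the parity check uses torch.allclose with default tolerances and dumps whole tensors on failure. A hedged alternative (not part of this commit) that reports only the worst mismatch, with explicit tolerances chosen as an assumption:

import torch

def assert_outputs_match(ref: torch.Tensor, out: torch.Tensor, rtol: float = 1e-5, atol: float = 1e-6) -> None:
    # torch.testing.assert_close prints the max absolute/relative error on failure,
    # which is easier to act on than printing full tensors.
    torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)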
@@ -1,21 +1,20 @@
import argparse
from typing import Dict

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, T5Tokenizer
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


@@ -23,21 +22,6 @@ def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
    texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
    data = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )
    data = {k: v.cuda() for k, v in data.items()}
    data["labels"] = data["input_ids"].clone()
    return data


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
        self.num_samples = num_samples

@@ -188,7 +172,6 @@ def main():
    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    test_mode = args.model_name == "test"

    # Set plugin
    booster_kwargs = {}
@@ -247,15 +230,20 @@
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build OpenMoe model
    config = MixtralConfig(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        use_cache=False,
    )
    model = MixtralForCausalLM(config).bfloat16()
    # config = MixtralConfig(
    #     hidden_size=2048,
    #     intermediate_size=4096,
    #     num_hidden_layers=4,
    #     num_attention_heads=4,
    #     num_key_value_heads=4,
    #     use_cache=False,
    # )
    config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
    config.use_cache = False
    # config.num_local_experts = 1
    init_ctx = LazyInitContext(default_device=get_current_device())
    with init_ctx:
        model = MixtralForCausalLM(config).bfloat16()
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Enable gradient checkpointing
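Note: the switch above builds the Mixtral-8x7B config but instantiates the model under LazyInitContext, so parameters stay as lazy placeholders until the booster/policy materializes them where they belong. A minimal sketch of the same pattern on a tiny hypothetical config (instead of the 8x7B checkpoint), not part of this commit:

import torch
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device

config = MixtralConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, use_cache=False,
)

# parameters created inside the context are lazy placeholders, not real allocations
with LazyInitContext(default_device=get_current_device()):
    model = MixtralForCausalLM(config).bfloat16()

# materialization normally happens inside booster.boost(); it can also be forced per module:
# LazyInitContext.materialize(model)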
@@ -270,7 +258,7 @@
    )

    # Set optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
@@ -292,7 +280,6 @@
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    exit()
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,

@@ -307,11 +294,9 @@
                        loss = outputs["loss"]
                        pbar.set_postfix({"loss": loss.item()})
                else:
                    print("1111111\n\n")
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    print(data)
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
@@ -0,0 +1 @@
1.0.0
@@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator


@@ -14,4 +15,6 @@ __all__ = [
    "UniformNoiseGenerator",
    "SparseMLP",
    "MoECheckpintIO",
    "MOE_MANAGER",
    "apply_load_balance",
]
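Note: re-exporting apply_load_balance at the package root lets callers pull everything MoE-related from one place, mirroring the import the training script now uses:

from colossalai.moe import MOE_MANAGER, SparseMLP, apply_load_balance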
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
            self.ep_size = 1

        if gated:
            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
            self.wi_gate = nn.Parameter(
                torch.empty(
                    num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
                )
            )
            self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
        else:
            self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
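Note: the change above sizes wi_gate by activation: swiglu fuses the gate and up projections into one doubled tensor, while silu keeps wi_gate and wi_up separate. A small sketch of the resulting parameter shapes, using hypothetical sizes:

import torch

num_experts, hidden_size, intermediate_size = 8, 32, 64  # hypothetical sizes

for activation in ("swiglu", "silu"):
    gate_cols = intermediate_size * 2 if activation == "swiglu" else intermediate_size
    wi_gate = torch.empty(num_experts, hidden_size, gate_cols)
    wi_up = torch.empty(num_experts, hidden_size, intermediate_size)
    # swiglu: wi_gate is (8, 32, 128) and already contains the up projection
    # silu:   wi_gate is (8, 32, 64) and wi_up holds the separate up projection
    print(activation, tuple(wi_gate.shape), tuple(wi_up.shape))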
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
        hidden_size: int,
        intermediate_size: int,
        router_top_k: int = 1,
        router_loss: bool = True,
        router_norm: bool = False,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,

@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_comm: bool = False,
        return_gate_logits: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.gated = mlp_gated
        self.return_gate_logits = return_gate_logits
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        self.expert_parallel = MOE_MANAGER.get_parallel()
        self.router_loss = router_loss
        self.router_norm = router_norm

        # moe router
        noisy_func = get_noise_generator(router_noisy_policy, num_experts)

@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
        tokens = inputs.reshape(-1, self.hidden_size)

        # the data type of the inputs in the gating should be fp32
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        gate_output = F.linear(fp32_input, fp32_weight)
        gate_logits = F.linear(tokens, self.gate_weight)
        gate_output = gate_logits.to(torch.float)

        # update expert load
        if self.enable_load_balance == True:
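Note: the gating change is easy to miss. The old path cast both tokens and the gate weight to fp32 before the matmul; the new path runs the gate matmul in the model dtype, keeps the raw logits for return_gate_logits, and only casts the logits to fp32 for routing. A standalone sketch of the two variants (hypothetical shapes, not the module itself):

import torch
import torch.nn.functional as F

tokens = torch.randn(16, 32, dtype=torch.bfloat16)       # (num_tokens, hidden_size), hypothetical
gate_weight = torch.randn(8, 32, dtype=torch.bfloat16)   # (num_experts, hidden_size), hypothetical

# old path: cast inputs and weight to fp32, then matmul
gate_output_old = F.linear(tokens.to(torch.float), gate_weight.to(torch.float))

# new path: matmul in model dtype, keep the bf16 logits for the caller, cast only for routing
gate_logits = F.linear(tokens, gate_weight)
gate_output_new = gate_logits.to(torch.float)

print(gate_output_old.dtype, gate_logits.dtype, gate_output_new.dtype)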
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):

        # the result from the router
        used_capacity, *route_result_list = self.router(
            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
            inputs=gate_output,
            use_kernel=self.enable_kernel,
            ep_group=self.ep_group,
            use_loss=self.router_loss,
            use_norm=self.router_norm,
        )

        # dispatch_data: (num_experts, capacity, hidden_size)
        if self.enable_kernel:

@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):

        # expert_output: (num_groups, num_experts, capacity, hidden_size)
        if self.expert_parallel == "EP":
            expert_output = self._ep_process(
                dispatch_data,
                used_capacity,
                overlap=self.enable_comm_overlap
            )
            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel == "TP":
            expert_output = self._tp_process(
                dispatch_data,
                used_capacity,
                overlap=self.enable_comm_overlap
            )
            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel is None:
            expert_output = self._local_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
                                      "Please use Experts build function.")
            raise NotImplementedError(
                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
            )

        if self.enable_kernel:
            expert_output = expert_output.reshape(-1, self.hidden_size)

@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
        ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)
        return ans

        if self.return_gate_logits:
            return ans, gate_logits
        else:
            return ans

    def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
        expert_in = expert_in.unsqueeze(0)

@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
        return expert_out

    def _ep_process(
        self,
        dispatch_data: torch.Tensor,
        used_capacity: torch.Tensor,
        overlap: bool = False
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """
        Expert Parallel

@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            if self.ep_hierarchical_group is not None:
                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                expert_input = HierarchicalAllToAll.apply(
                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                expert_output = HierarchicalAllToAll.apply(
                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                return expert_output
            else:
                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]

@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
            chunk_size = dispatch_data.shape[1] // NUM_CHUNK
            input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
            dispatch_data = dispatch_data.reshape(*input_shape)

@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                    offset += chunk_size
                    expert_out = None

                # all2all last output
                if _expert_out is not None:
                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
                    expert_out = Capsule(
                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
                    )
                    _expert_out = None

                # all2all next input
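Note: the overlap path splits the capacity dimension into NUM_CHUNK pieces so expert compute on one chunk can hide the all-to-all of the next, writing results back at a running offset. A stripped-down, single-process sketch of just the chunking and offset bookkeeping (no communication, hypothetical shapes; not the real pipeline):

import torch

NUM_CHUNK = 4
dispatch_data = torch.randn(8, 16, 32)            # (num_experts, capacity, hidden), hypothetical
assert dispatch_data.shape[1] % NUM_CHUNK == 0
chunk_size = dispatch_data.shape[1] // NUM_CHUNK

output = torch.empty_like(dispatch_data)
offset = 0
for chunk in torch.split(dispatch_data, chunk_size, dim=1):
    # stand-in for expert compute; in SparseMLP this overlaps with an async AllToAll
    expert_out = chunk * 2.0
    output[:, offset : offset + chunk_size, :] = expert_out
    offset += chunk_size

assert torch.equal(output, dispatch_data * 2.0)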
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
            return output

    def _tp_process(
        self,
        dispatch_data: torch.Tensor,
        used_capacity: torch.Tensor,
        overlap: bool = False
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """
        without overlap:

@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
            assert (
                dispatch_data.shape[0] % NUM_CHUNK == 0
            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
            chunk_size = dispatch_data.shape[0] // NUM_CHUNK
            chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
            output = torch.empty_like(dispatch_data)
@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(self,
                 k_value: int,
                 capacity_factor_train: float,
                 capacity_factor_eval: float,
                 min_capacity: int,
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True,
                 use_kernel: bool = False):
    def __init__(
        self,
        k_value: int,
        capacity_factor_train: float,
        capacity_factor_eval: float,
        min_capacity: int,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
        use_kernel: bool = False,
    ):
        super().__init__()
        self.k_value = k_value
        self.capacity_factor_train = capacity_factor_train

@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
        if router_probs.dim() == expert_indices.dim() == 2:
            router_probs = router_probs.unsqueeze(0)
            expert_indices = expert_indices.unsqueeze(0)
        assert router_probs.dim() == expert_indices.dim() == 3, \
            "router_probs must be 3D tensor and expert_indices must be 4D tensor"
        assert (
            router_probs.dim() == expert_indices.dim() == 3
        ), "router_probs must be 3D tensor and expert_indices must be 4D tensor"

        # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
        expert_mask = F.one_hot(expert_indices, num_experts)

@@ -122,25 +125,28 @@ class Top1Router(MoeRouter):
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 select_policy: str = "first",
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__(k_value=1,
                         capacity_factor_train=capacity_factor_train,
                         capacity_factor_eval=capacity_factor_eval,
                         min_capacity=min_capacity,
                         noisy_func=noisy_func,
                         drop_tks=drop_tks)
    def __init__(
        self,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        select_policy: str = "first",
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            k_value=1,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )
        self.select_policy = select_policy
        assert select_policy in {"first", "random"}
        if select_policy == "random":
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=get_current_device()),
                high=torch.tensor(1.0, device=get_current_device())
                low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
            ).rsample

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:

@@ -200,7 +206,7 @@ class Top1Router(MoeRouter):
        weight = mask * probs.type_as(inputs)
        combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
        sec_mask = combine_weights.bool()
        return used_capacity, combine_weights, sec_mask
        return used_capacity, combine_weights, sec_mask, probs


class Top2Router(MoeRouter):

@@ -216,20 +222,31 @@ class Top2Router(MoeRouter):
        drop_tks (bool, optional): Whether drops tokens in evaluation.
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__(k_value=2,
                         capacity_factor_train=capacity_factor_train,
                         capacity_factor_eval=capacity_factor_eval,
                         min_capacity=min_capacity,
                         noisy_func=noisy_func,
                         drop_tks=drop_tks)
    def __init__(
        self,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            k_value=2,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
    def forward(
        self,
        inputs: torch.Tensor,
        use_kernel: bool = False,
        ep_group: Optional[ProcessGroup] = None,
        use_norm: bool = False,
        use_loss: bool = True,
    ) -> Tuple:
        """
        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -246,6 +263,10 @@ class Top2Router(MoeRouter):

        assert inputs.dtype == torch.float
        probs = F.softmax(inputs, dim=-1)
        if use_norm:
            routing_weights, _ = torch.topk(probs, 2, dim=-1)
            probs = probs / routing_weights.sum(dim=-1, keepdim=True)

        num_experts = probs.size(-1)
        capacity = self.get_capacity(inputs.shape)
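Note: the use_norm branch matches Mixtral's routing convention: softmax over all experts, then renormalize so the two selected experts' weights sum to 1. A small worked sketch with hypothetical logits:

import torch
import torch.nn.functional as F

gate_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])       # one token, four experts (hypothetical)
probs = F.softmax(gate_logits, dim=-1)                     # softmax over all four experts

routing_weights, _ = torch.topk(probs, 2, dim=-1)          # top-2 probabilities
normed = probs / routing_weights.sum(dim=-1, keepdim=True)

# after renormalization, the two selected experts' weights sum to 1
print(torch.topk(normed, 2, dim=-1).values.sum(dim=-1))    # tensor([1.])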
@@ -255,21 +276,22 @@ class Top2Router(MoeRouter):
        top2_idx = torch.argmax(logits_except1, dim=-1)
        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

        cmask = (mask1 + mask2)    # loss: [s, e]
        cmask = cmask.float() / 2.0    # div 2 to normalize it to 1
        cmask = mask1 + mask2  # loss: [s, e]
        cmask = cmask.float() / 2.0  # div 2 to normalize it to 1

        # calculate loss
        expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
        self.set_aux_loss(probs, expert_indices, num_experts)
        self.set_z_loss(inputs)
        self.pop_router_loss()
        if use_loss:
            expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
            self.set_aux_loss(probs, expert_indices, num_experts)
            self.set_z_loss(inputs)
            self.pop_router_loss()

        if not self.training and not self.drop_tks and ep_group is not None:
            max_num = torch.max(torch.sum(cmask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)    # rank1: [s, e]
        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
        rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
        rank2 += torch.sum(mask1, dim=-2, keepdim=True)

@@ -336,15 +358,18 @@ class TopKRouter(MoeRouter):
        oversubscribed / reach capacity.
    """

    def __init__(self,
                 num_selected_experts: int,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
                         drop_tks)
    def __init__(
        self,
        num_selected_experts: int,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
        )

    def forward(
        self,

@@ -410,7 +435,7 @@ class TopKRouter(MoeRouter):
        # The combine array will be used for combining expert outputs, scaled by the
        # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
        # expert_capacity].
        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

        return combine_array, dispatch_mask
@@ -13,7 +13,6 @@ from colossalai.utils import get_current_device


class ForceFP32Parameter(torch.nn.Parameter):

    def half(self, memory_format=None):
        return self.data.clone()


@@ -84,6 +83,8 @@ def get_activation(act: str) -> Callable:
        return torch.nn.GELU()
    elif act == "swiglu":
        return SwiGLU
    elif act == "silu":
        return torch.nn.SiLU()
    else:
        raise NotImplementedError("Unsupported activation function")
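Note: with the new branch, mlp_activation="silu" maps to torch.nn.SiLU, which is what the Mixtral experts use. A quick standalone sketch of the mapping (it mirrors the branch above rather than importing the real helper):

import torch

def pick_activation(act: str) -> torch.nn.Module:
    # mirrors the branch added above; "swiglu" is handled separately in the real helper
    if act == "gelu":
        return torch.nn.GELU()
    elif act == "silu":
        return torch.nn.SiLU()
    raise NotImplementedError("Unsupported activation function")

x = torch.randn(4)
assert torch.allclose(pick_activation("silu")(x), torch.nn.functional.silu(x))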
@@ -142,7 +143,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
    epsize_param_dict = dict()
    for param in model.parameters():
        if not is_moe_tensor(param):
            ep_size = 1    # set ep_size to 1 for dp parameters
            ep_size = 1  # set ep_size to 1 for dp parameters
        else:
            ep_size = get_ep_size(param)
        if ep_size not in epsize_param_dict:
@@ -193,18 +194,13 @@ def create_ep_hierarchical_group(
        assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
        nproc_per_node = int(nproc_per_node)
    else:
        assert dist.get_world_size() % nproc_per_node == 0, \
            "nproc_per_node should be a divisor of world_size."
        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
    num_node = dist.get_world_size() // nproc_per_node

    intra_src_rank = None
    ep_intra_node_group = None
    for i in range(num_node):
        ep_intra_ranks = [
            i * nproc_per_node + j
            for j in range(nproc_per_node)
            if j in ep_group_ranks
        ]
        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
        group = dist.new_group(ep_intra_ranks)
        if rank in ep_intra_ranks:
            assert ep_intra_node_group is None

@@ -212,10 +208,7 @@ def create_ep_hierarchical_group(
            intra_src_rank = ep_intra_ranks[0]

    ep_inter_node_group = None
    ep_inter_ranks = [
        ep_group_ranks[0] + i * nproc_per_node
        for i in range(num_node)
    ]
    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    if len(ep_inter_ranks) > 1:
        group = dist.new_group(ep_inter_ranks)
        if rank in ep_inter_ranks:
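Note: the two rank lists above split an expert-parallel group into per-node subgroups plus one cross-node group of node representatives. A standalone sketch of just the arithmetic (no torch.distributed required), assuming 2 nodes with 4 GPUs each and a hypothetical EP group rank layout:

nproc_per_node = 4
num_node = 2
ep_group_ranks = list(range(nproc_per_node))  # hypothetical EP group rank layout

for i in range(num_node):
    ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    print("intra-node group", i, ep_intra_ranks)   # [0, 1, 2, 3] and [4, 5, 6, 7]

ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
print("inter-node group", ep_inter_ranks)          # [0, 4]: one representative rank per node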