From f66469e20986f656ffed29eff97bb2987bf88fef Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 15 Dec 2023 16:32:32 +0800 Subject: [PATCH] update --- .../colossal_moe/models/mixtral_layer.py | 123 +++++++++------- applications/ColossalMoE/setup.py | 43 ++++++ applications/ColossalMoE/tests/__init__.py | 0 .../ColossalMoE/tests/test_moe_layer.py | 31 ++++ applications/ColossalMoE/train.py | 53 +++---- applications/ColossalMoE/version.txt | 1 + colossalai/moe/__init__.py | 5 +- colossalai/moe/experts.py | 6 +- colossalai/moe/layers.py | 72 +++++---- colossalai/moe/routers.py | 137 +++++++++++------- colossalai/moe/utils.py | 19 +-- 11 files changed, 304 insertions(+), 186 deletions(-) create mode 100644 applications/ColossalMoE/setup.py create mode 100644 applications/ColossalMoE/tests/__init__.py create mode 100644 applications/ColossalMoE/tests/test_moe_layer.py create mode 100644 applications/ColossalMoE/version.txt diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py index 71ee5ff1b..fb5cfc588 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py @@ -4,6 +4,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from colossalai.lazy import LazyInitContext from colossalai.moe import SparseMLP +from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor class MixtralSparseMLP: @@ -33,57 +34,81 @@ class MixtralSparseMLP: Raises: AssertionError: If the provided module is not an instance of nn.LayerNorm. """ + with torch.no_grad(): + LazyInitContext.materialize(module) - LazyInitContext.materialize(module) - # get the attributes of the module - moe_kwargs = dict( - num_experts=module.num_experts, - hidden_size=module.hidden_dim, - intermediate_size=module.ffn_dim, - router_top_k=module.top_k, - # router_capacity_factor_train = module. - # router_capacity_factor_eval = module. - # router_min_capacity = module. - # router_noisy_policy = module. - # router_drop_tks = module. - mlp_activation="silu", - mlp_gated=True, - # enable_load_balance = module. - # load_balance_tolerance = module. - # load_balance_beam_width = module. - # load_balance_group_swap_factor = module. - # enable_kernel = module. - # enable_comm_overlap = module. - # enable_hierarchical_comm = module. - return_gate_logits=True, - ) - dtype = module.gate.weight.dtype - device = module.gate.weight.device + # get the attributes of the module + moe_kwargs = dict( + num_experts=module.num_experts, + hidden_size=module.hidden_dim, + intermediate_size=module.ffn_dim, + router_top_k=module.top_k, + router_norm=True, + router_loss=False, + # router_capacity_factor_train = . + # router_capacity_factor_eval = . + mlp_activation="silu", + mlp_gated=True, + # enable_load_balance = . + # load_balance_tolerance = . + # load_balance_beam_width = . + # load_balance_group_swap_factor = . + # enable_kernel = . + # enable_comm_overlap = . + # enable_hierarchical_comm = . + return_gate_logits=True, + ) + dtype = module.gate.weight.dtype + device = module.gate.weight.device + sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device) - sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device) - w1 = None - w2 = None - w3 = None - for i in module.experts: - wi_1 = i.w1.weight.data.transpose(0, 1).unsqueeze(0) - wi_2 = i.w2.weight.data.transpose(0, 1).unsqueeze(0) - wi_3 = i.w3.weight.data.transpose(0, 1).unsqueeze(0) - if w1 is None: - w1 = wi_1 + # cat all experts + w1 = None + w2 = None + w3 = None + for i in module.experts: + # origin + wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0) + wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0) + wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0) + # cat + w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0) + w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0) + w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0) + + # get local experts + if is_moe_tensor(sparse_mlp.experts.wi_gate): + ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate) + expert_num = sparse_mlp.experts.wi_gate.shape[0] + expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num) else: - w1 = torch.cat([w1, wi_1], dim=0) - if w2 is None: - w2 = wi_2 - else: - w2 = torch.cat([w2, wi_2], dim=0) - if w3 is None: - w3 = wi_3 - else: - w3 = torch.cat([w3, wi_3], dim=0) + expert_slice = slice(None) + w1 = w1[expert_slice].clone().detach() + w2 = w2[expert_slice].clone().detach() + w3 = w3[expert_slice].clone().detach() + assert ( + w1.shape == sparse_mlp.experts.wi_gate.shape + ), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}" + assert ( + w2.shape == sparse_mlp.experts.wo.shape + ), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}" + assert ( + w3.shape == sparse_mlp.experts.wi_up.shape + ), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}" - sparse_mlp.experts.wi_gate.data = w1[:2] - sparse_mlp.experts.wi_up.data = w3[:2] - sparse_mlp.experts.wo.data = w2[:2] - sparse_mlp.gate_weight = module.gate.weight + # assign new param to colossal moe moudle + sparse_mlp.experts.wi_gate.data = w1 + sparse_mlp.experts.wi_up.data = w3 + sparse_mlp.experts.wo.data = w2 + sparse_mlp.gate_weight = module.gate.weight - return sparse_mlp.to(dtype).to(device) + # TODO: fix + # the old weight is referenced somewhere so we can not del it. + # Change data pointer of old weight to release memory. + # The pointer will not be used and can be any pointer. + for i in module.experts: + i.w1.weight.data = w1 + i.w2.weight.data = w2 + i.w3.weight.data = w3 + + return sparse_mlp diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py new file mode 100644 index 000000000..275f59e10 --- /dev/null +++ b/applications/ColossalMoE/setup.py @@ -0,0 +1,43 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def fetch_version(): + with open("version.txt", "r") as f: + return f.read().strip() + + +setup( + name="colossal_moe", + version=fetch_version(), + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI MoE", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/tests/test_moe_layer.py b/applications/ColossalMoE/tests/test_moe_layer.py new file mode 100644 index 000000000..8b090c427 --- /dev/null +++ b/applications/ColossalMoE/tests/test_moe_layer.py @@ -0,0 +1,31 @@ +import copy + +import torch +from colossal_moe.models.mixtral_layer import MixtralSparseMLP +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + +class Config: + def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.hidden_act = hidden_act + + +def test_moe_layer(): + config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu") + mistral_moe = MixtralSparseMoeBlock(config).cuda() + colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda() + + data = torch.randn(2, 8, 4).cuda() + mistral_output = mistral_moe(data)[0] + colossal_output = colossal_moe(data)[0] + assert torch.allclose( + mistral_output, colossal_output + ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}" + + +if __name__ == "__main__": + test_moe_layer() diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py index 32e9c1e15..9c059c7b8 100644 --- a/applications/ColossalMoE/train.py +++ b/applications/ColossalMoE/train.py @@ -1,21 +1,20 @@ import argparse -from typing import Dict import torch import torch.distributed as dist from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm -from transformers import AutoTokenizer, T5Tokenizer +from transformers import AutoTokenizer from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM -from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM import colossalai from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator -from colossalai.moe.layers import apply_load_balance -from colossalai.moe.manager import MOE_MANAGER +from colossalai.lazy import LazyInitContext +from colossalai.moe import MOE_MANAGER, apply_load_balance +from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -23,21 +22,6 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict: - texts = ["" + sample["prompt"] + sample["completion"] for sample in batch] - data = tokenizer( - texts, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=max_length, - add_special_tokens=False, - ) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - return data - - class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): self.num_samples = num_samples @@ -188,7 +172,6 @@ def main(): # Launch ColossalAI colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() - test_mode = args.model_name == "test" # Set plugin booster_kwargs = {} @@ -247,15 +230,20 @@ def main(): coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") # Build OpenMoe model - config = MixtralConfig( - hidden_size=32, - intermediate_size=64, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=4, - use_cache=False, - ) - model = MixtralForCausalLM(config).bfloat16() + # config = MixtralConfig( + # hidden_size=2048, + # intermediate_size=4096, + # num_hidden_layers=4, + # num_attention_heads=4, + # num_key_value_heads=4, + # use_cache=False, + # ) + config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1") + config.use_cache = False + # config.num_local_experts = 1 + init_ctx = LazyInitContext(default_device=get_current_device()) + with init_ctx: + model = MixtralForCausalLM(config).bfloat16() coordinator.print_on_master(f"Finish init model with config:\n{config}") # Enable gradient checkpointing @@ -270,7 +258,7 @@ def main(): ) # Set optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # Set booster booster = Booster(plugin=plugin, **booster_kwargs) @@ -292,7 +280,6 @@ def main(): ) as pbar: for step in pbar: if use_pipeline: - exit() # Forward pass outputs = booster.execute_pipeline( train_dataloader_iter, @@ -307,11 +294,9 @@ def main(): loss = outputs["loss"] pbar.set_postfix({"loss": loss.item()}) else: - print("1111111\n\n") # Forward pass data = next(train_dataloader_iter) data = move_to_cuda(data, torch.cuda.current_device()) - print(data) outputs = model(**data) loss = outputs["loss"] # Backward diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt new file mode 100644 index 000000000..3eefcb9dd --- /dev/null +++ b/applications/ColossalMoE/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 721da69d0..6dd0a5fc3 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,6 +1,7 @@ from .checkpoint import MoECheckpintIO from .experts import MLPExperts -from .layers import SparseMLP +from .layers import SparseMLP, apply_load_balance +from .manager import MOE_MANAGER from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator @@ -14,4 +15,6 @@ __all__ = [ "UniformNoiseGenerator", "SparseMLP", "MoECheckpintIO", + "MOE_MANAGER", + "apply_load_balance", ] diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 477b76547..8e6ea3884 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -67,7 +67,11 @@ class MLPExperts(nn.Module): self.ep_size = 1 if gated: - self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_gate = nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size + ) + ) self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) else: self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index b768fb94a..2ac5b186d 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,6 +51,8 @@ class SparseMLP(nn.Module): hidden_size: int, intermediate_size: int, router_top_k: int = 1, + router_loss: bool = True, + router_norm: bool = False, router_capacity_factor_train: float = 1.25, router_capacity_factor_eval: float = 2.0, router_min_capacity: int = 4, @@ -65,15 +67,19 @@ class SparseMLP(nn.Module): enable_kernel: bool = False, enable_comm_overlap: bool = False, enable_hierarchical_comm: bool = False, + return_gate_logits: bool = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts self.gated = mlp_gated + self.return_gate_logits = return_gate_logits self.enable_kernel = enable_kernel self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() + self.router_loss = router_loss + self.router_norm = router_norm # moe router noisy_func = get_noise_generator(router_noisy_policy, num_experts) @@ -150,9 +156,8 @@ class SparseMLP(nn.Module): tokens = inputs.reshape(-1, self.hidden_size) # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) + gate_logits = F.linear(tokens, self.gate_weight) + gate_output = gate_logits.to(torch.float) # update expert load if self.enable_load_balance == True: @@ -165,7 +170,12 @@ class SparseMLP(nn.Module): # the result from the router used_capacity, *route_result_list = self.router( - inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + inputs=gate_output, + use_kernel=self.enable_kernel, + ep_group=self.ep_group, + use_loss=self.router_loss, + use_norm=self.router_norm, + ) # dispatch_data: (num_experts, capacity, hidden_size) if self.enable_kernel: @@ -177,22 +187,15 @@ class SparseMLP(nn.Module): # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n" - "Please use Experts build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n" "Please use Experts build function." + ) if self.enable_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) @@ -204,7 +207,11 @@ class SparseMLP(nn.Module): ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - return ans + + if self.return_gate_logits: + return ans, gate_logits + else: + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) @@ -212,10 +219,7 @@ class SparseMLP(nn.Module): return expert_out def _ep_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ Expert Parallel @@ -228,10 +232,14 @@ class SparseMLP(nn.Module): """ if not overlap or dist.get_world_size(self.ep_group) == 1: if self.ep_hierarchical_group is not None: - expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_input = HierarchicalAllToAll.apply( + dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank + ) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) - expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_output = HierarchicalAllToAll.apply( + expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank + ) return expert_output else: expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] @@ -249,7 +257,7 @@ class SparseMLP(nn.Module): NUM_CHUNK = 4 NUM_STAGES = 4 - assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" + assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -262,13 +270,15 @@ class SparseMLP(nn.Module): for i in range(NUM_CHUNK + NUM_STAGES - 1): if expert_out is not None: expert_out.handle.wait() - output[:, :, offset:offset + chunk_size, :] = expert_out.data + output[:, :, offset : offset + chunk_size, :] = expert_out.data offset += chunk_size expert_out = None # all2all last output if _expert_out is not None: - expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) _expert_out = None # all2all next input @@ -288,10 +298,7 @@ class SparseMLP(nn.Module): return output def _tp_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ without overlap: @@ -326,8 +333,9 @@ class SparseMLP(nn.Module): NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert ( + dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index c5bb50862..a891ee7df 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False): + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + use_kernel: bool = False, + ): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC): if router_probs.dim() == expert_indices.dim() == 2: router_probs = router_probs.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0) - assert router_probs.dim() == expert_indices.dim() == 3, \ - "router_probs must be 3D tensor and expert_indices must be 4D tensor" + assert ( + router_probs.dim() == expert_indices.dim() == 3 + ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. expert_mask = F.one_hot(expert_indices, num_experts) @@ -122,25 +125,28 @@ class Top1Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) + low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) ).rsample def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: @@ -200,7 +206,7 @@ class Top1Router(MoeRouter): weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask + return used_capacity, combine_weights, sec_mask, probs class Top2Router(MoeRouter): @@ -216,20 +222,31 @@ class Top2Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation. """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_norm: bool = False, + use_loss: bool = True, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -246,6 +263,10 @@ class Top2Router(MoeRouter): assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) + if use_norm: + routing_weights, _ = torch.topk(probs, 2, dim=-1) + probs = probs / routing_weights.sum(dim=-1, keepdim=True) + num_experts = probs.size(-1) capacity = self.get_capacity(inputs.shape) @@ -255,21 +276,22 @@ class Top2Router(MoeRouter): top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - cmask = (mask1 + mask2) # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 + cmask = mask1 + mask2 # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() + if use_loss: + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) @@ -336,15 +358,18 @@ class TopKRouter(MoeRouter): oversubscribed / reach capacity. """ - def __init__(self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, - drop_tks) + def __init__( + self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks + ) def forward( self, @@ -410,7 +435,7 @@ class TopKRouter(MoeRouter): # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_array, dispatch_mask diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 5a17a6e0d..b5d62dd70 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -13,7 +13,6 @@ from colossalai.utils import get_current_device class ForceFP32Parameter(torch.nn.Parameter): - def half(self, memory_format=None): return self.data.clone() @@ -84,6 +83,8 @@ def get_activation(act: str) -> Callable: return torch.nn.GELU() elif act == "swiglu": return SwiGLU + elif act == "silu": + return torch.nn.SiLU() else: raise NotImplementedError("Unsupported activation function") @@ -142,7 +143,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] epsize_param_dict = dict() for param in model.parameters(): if not is_moe_tensor(param): - ep_size = 1 # set ep_size to 1 for dp parameters + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = get_ep_size(param) if ep_size not in epsize_param_dict: @@ -193,18 +194,13 @@ def create_ep_hierarchical_group( assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." nproc_per_node = int(nproc_per_node) else: - assert dist.get_world_size() % nproc_per_node == 0, \ - "nproc_per_node should be a divisor of world_size." + assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size." num_node = dist.get_world_size() // nproc_per_node intra_src_rank = None ep_intra_node_group = None for i in range(num_node): - ep_intra_ranks = [ - i * nproc_per_node + j - for j in range(nproc_per_node) - if j in ep_group_ranks - ] + ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks] group = dist.new_group(ep_intra_ranks) if rank in ep_intra_ranks: assert ep_intra_node_group is None @@ -212,10 +208,7 @@ def create_ep_hierarchical_group( intra_src_rank = ep_intra_ranks[0] ep_inter_node_group = None - ep_inter_ranks = [ - ep_group_ranks[0] + i * nproc_per_node - for i in range(num_node) - ] + ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)] if len(ep_inter_ranks) > 1: group = dist.new_group(ep_inter_ranks) if rank in ep_inter_ranks: