@@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
 
 
 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None):
-        super().__init__(config)
-        self.setup_process_groups(ep_group, tp_group, moe_tp_group)
+    def __init__(self, *args, **kwargs):
+        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+        assert tp_group is not None
+        assert moe_dp_group is not None
+        assert ep_group is not None
+        assert moe_tp_group is not None
+
         # setup ep group
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
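
Note: the import change above is a rename plus a split. The old MoeInGradScaler/MoeOutGradScaler pair becomes EPGradScalerIn/EPGradScalerOut (gradient compensation across the expert-parallel group), and the new DPGradScalerIn/DPGradScalerOut pair compensates across the MoE data-parallel group that setup_process_groups now receives. A minimal sketch of the shared pattern, identity forward with a scaled backward; the class names and the multiply/divide pairing below are illustrative assumptions, not the actual colossalai.moe._operation implementations:

```python
import torch


class _GradScaleIn(torch.autograd.Function):
    """Identity forward; multiplies the incoming gradient by `factor`."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Scale up on the way into the parallel region...
        return grad_output * ctx.factor, None


class _GradScaleOut(torch.autograd.Function):
    """Identity forward; divides the incoming gradient by `factor`."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # ...and scale back down on the way out, so the net gradient
        # magnitude outside the region is unchanged.
        return grad_output / ctx.factor, None
```

Wrapping a region in an In/Out pair of this kind leaves the end-to-end gradient unchanged while correcting the scaling that the collective inside the region introduces.
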
@@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
-            p.ep_group = ep_group
+            set_moe_tensor_ep_group(p, ep_group)
+
+        # setup moe_dp group
+        self.moe_dp_group = moe_dp_group
+        self.moe_dp_size = moe_dp_group.size()
 
         # setup global tp group
         self.tp_group = tp_group
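
Note: replacing the bare attribute write `p.ep_group = ep_group` with `set_moe_tensor_ep_group(p, ep_group)` moves the tagging convention behind one helper, so optimizer and checkpoint code can query a parameter's EP group through a single API instead of relying on an ad-hoc attribute. A sketch of the idea only; the attribute name and getter below are hypothetical, not ColossalAI's actual API:

```python
from typing import Optional

import torch.nn as nn
from torch.distributed import ProcessGroup


def tag_ep_group(p: nn.Parameter, ep_group: ProcessGroup) -> None:
    # Single owner for the tagging convention: if the scheme changes, only this
    # helper and its getter need updating, not every call site.
    p.moe_info = {"ep_group": ep_group}


def get_ep_group(p: nn.Parameter) -> Optional[ProcessGroup]:
    info = getattr(p, "moe_info", None)
    return None if info is None else info["ep_group"]
```
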
@@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None, *args, **kwargs
+        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
     ) -> "EPMixtralSparseMoeBlock":
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(ep_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
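
Note: since `__init__` now raises, `from_native_module` is the only supported constructor, and all four groups are required, in (tp_group, moe_dp_group, ep_group, moe_tp_group) order. A hypothetical single-process call site, reusing the trivial world group of size 1 for every role so it runs without a multi-GPU launch; it assumes transformers and colossalai are installed and EPMixtralSparseMoeBlock is in scope:

```python
import os

import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Single-process bootstrap; every group degenerates to the world group of size 1.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
group = dist.group.WORLD

# Small config so the example is cheap to build.
config = MixtralConfig(hidden_size=32, intermediate_size=64, num_local_experts=4, num_experts_per_tok=2)
block = MixtralSparseMoeBlock(config)
block = EPMixtralSparseMoeBlock.from_native_module(
    block, tp_group=group, moe_dp_group=group, ep_group=group, moe_tp_group=group
)
assert isinstance(block, EPMixtralSparseMoeBlock)
```
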
@@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
+        with torch.no_grad():
+            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+            for i in range(1, self.ep_size):
+                activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+            activate_experts = (activate_experts > 0).float()
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
 
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
             if self.num_experts_per_ep == 1:
                 # no need to split
                 expert = self.experts[self.expert_start_idx]
+                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
                 output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                 output_states = expert.w2(output_states)
+                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
             else:
                 output_states_splits = output_states.split(output_split_sizes.tolist())
                 output_states_list = []
                 for i, split_states in enumerate(output_states_splits):
                     if split_states.size(0) == 0:
                         continue
+                    split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                     split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                     split_states = expert.w2(split_states)
+                    split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+
+        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
 
         recover_experts_idx = torch.empty_like(selected_experts_idx)
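
Note: the added no_grad block tallies, for each locally held expert, the tokens arriving from every EP peer, converts that to a 0/1 activation mask, and all-reduces it over the MoE data-parallel group; DPGradScalerIn/DPGradScalerOut then use the per-expert activation count so expert replicas that received no tokens this step still end up with consistently scaled gradients. A toy, single-process walk-through of the split-size bookkeeping, with invented numbers (2 EP ranks, 2 experts per rank):

```python
import torch

ep_size, num_experts_per_ep = 2, 2

# Tokens this rank routes to global experts 0..3; summing over each EP rank's
# slice of experts gives the per-rank send sizes for all_to_all_uneven.
input_split_sizes = torch.tensor([3, 0, 5, 2])
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
print(input_split_list)  # [3, 7]: 3 tokens to EP rank 0, 7 to EP rank 1

# After dist.all_to_all_single, output_split_sizes holds the counts each EP peer
# sends to this rank's two local experts: [from rank 0: e0, e1, from rank 1: e0, e1].
output_split_sizes = torch.tensor([5, 2, 1, 0])
activate_experts = output_split_sizes[:num_experts_per_ep].clone()
for i in range(1, ep_size):
    activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
print(activate_experts)  # tensor([1., 1.]): both local experts see traffic
# (in the real code this mask is then all-reduced over moe_dp_group)
```
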