
[moe] add mixtral dp grad scaling when not all experts are activated

colossalchat
botbw authored 4 months ago, committed by Hongxin Liu
parent · commit 9b9b76bdcd
  1. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (6 changes)
  2. colossalai/moe/_operation.py (51 changes)
  3. colossalai/shardformer/layer/moe/experts.py (6 changes)
  4. colossalai/shardformer/layer/moe/routers.py (6 changes)
  5. colossalai/shardformer/modeling/mixtral.py (46 changes)
  6. colossalai/shardformer/policies/mixtral.py (14 changes)
  7. colossalai/shardformer/shard/shard_config.py (3 changes)
  8. tests/test_moe/test_moe_ep_zero.py (8 changes)

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (6 changes)

@@ -141,6 +141,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# set ep_group after super init
# TODO do it in a better way
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
self.force_overlap_comm = force_overlap_comm
@@ -159,7 +161,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is dp, axis 1 is tp, axis 2 is sp
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
# hardcode here since we only have 3 axis
@@ -188,7 +190,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
assert self.moe_tp_group is None
self.moe_tp_group = group
self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
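To make the axis layout in the hunk above easier to follow, here is a minimal, self-contained sketch of how a pipeline stage's flat rank list maps onto the (moe_dp, ep, moe_tp) axes. The group sizes are made up for illustration; the plugin itself builds real torch.distributed process groups from slices like these, while this sketch only prints the rank lists.

import numpy as np

# illustrative sizes, not taken from the commit
moe_dp_size, ep_size, moe_tp_size = 2, 2, 1
stage_ranks = list(range(moe_dp_size * ep_size * moe_tp_size))

# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp (as in the comment above)
submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

# an ep group varies the ep axis while fixing the other two
for i in range(moe_dp_size):
    for k in range(moe_tp_size):
        print("ep group:", submesh[i, :, k].tolist())

# a moe_dp group varies the moe_dp axis while fixing the other two
for j in range(ep_size):
    for k in range(moe_tp_size):
        print("moe_dp group:", submesh[:, j, k].tolist())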

colossalai/moe/_operation.py (51 changes)

@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
return torch.cumsum(inputs, dim=0) - 1
class MoeInGradScaler(torch.autograd.Function):
class EPGradScalerIn(torch.autograd.Function):
"""
Scale the gradient back by the number of experts
because the batch size increases in the moe stage
@@ -298,8 +297,7 @@ class MoeInGradScaler(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
if ctx is not None:
ctx.ep_size = ep_size
ctx.ep_size = ep_size
return inputs
@staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
return grad, None
class MoeOutGradScaler(torch.autograd.Function):
class EPGradScalerOut(torch.autograd.Function):
"""
Scale the gradient by the number of experts
because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
return grad, None
class DPGradScalerIn(torch.autograd.Function):
"""
Scale the gradient by activated_experts / moe_dp_size in backward,
undoing the scaling applied by DPGradScalerOut so the gradient that
flows back to the expert's input keeps its original scale.
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
assert activated_experts != 0, "DPGradScalerIn should not be called when no expert is activated"
ctx.moe_dp_size = moe_dp_size
ctx.activated_experts = activated_experts
return inputs
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.moe_dp_size != ctx.activated_experts:
grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
return grad, None, None
class DPGradScalerOut(torch.autograd.Function):
"""
Scale the gradient by moe_dp_size / activated_experts in backward,
compensating the gradient averaging over the moe dp group when only
some of the moe dp ranks routed tokens to this expert.
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
assert activated_experts != 0, "DPGradScalerOut should not be called when no expert is activated"
ctx.moe_dp_size = moe_dp_size
ctx.activated_experts = activated_experts
return inputs
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.moe_dp_size != ctx.activated_experts:
grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
return grad, None, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
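A minimal sketch of how the new DPGradScalerIn/DPGradScalerOut pair is meant to be composed around a single expert; the Mixtral forward further down in this commit wraps each expert the same way. The import path follows the location of the classes in this file, while the tensor shapes and the moe_dp_size / activated_experts values are made up.

import torch
import torch.nn as nn

from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut

moe_dp_size, activated_experts = 4, 2          # only 2 of the 4 moe-dp ranks hit this expert
x = torch.randn(8, 16, requires_grad=True)
expert = nn.Linear(16, 16)                     # stands in for one expert's MLP

h = DPGradScalerIn.apply(x, moe_dp_size, activated_experts)
out = DPGradScalerOut.apply(expert(h), moe_dp_size, activated_experts)
out.sum().backward()

# Both scalers are identity in forward. In backward, DPGradScalerOut multiplies
# the incoming gradient by moe_dp_size / activated_experts, so the expert's
# parameter gradients are boosted to compensate gradient averaging over the
# moe dp group; DPGradScalerIn then multiplies by the inverse factor, so the
# gradient reaching x keeps its original scale.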

colossalai/shardformer/layer/moe/experts.py (6 changes)

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)
x = EPGradScalerIn.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
x = EPGradScalerOut.apply(x, self.ep_size)
return x

colossalai/shardformer/layer/moe/routers.py (6 changes)

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)
x = EPGradScalerIn.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
x = EPGradScalerOut.apply(x, self.ep_size)
return x

colossalai/shardformer/modeling/mixtral.py (46 changes)

@@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
super().__init__(config)
self.setup_process_groups(ep_group, tp_group, moe_tp_group)
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
assert moe_tp_group is not None
def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
# setup ep group
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
@@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
set_moe_tensor_ep_group(p, ep_group)
# setup moe_dp group
self.moe_dp_group = moe_dp_group
self.moe_dp_size = moe_dp_group.size()
# setup global tp group
self.tp_group = tp_group
@@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
@staticmethod
def from_native_module(
module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
) -> "EPMixtralSparseMoeBlock":
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_process_groups(ep_group)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():
activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
if self.tp_group is not None and self.tp_group.size() > 1:
if self.tp_group.size() > 1:
dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
if self.tp_group is not None and self.tp_group.size() > 1:
if self.tp_group.size() > 1:
dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
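For readers following the new bookkeeping in forward, the sketch below replays the activate_experts computation on made-up split sizes (ep_size = 2, num_experts_per_ep = 2, single process). The all_reduce over moe_dp_group is only indicated in a comment because it needs an initialized process group.

import torch

ep_size, num_experts_per_ep = 2, 2
# per (source ep rank, local expert) token counts this rank will receive,
# i.e. what output_split_sizes holds after the all_to_all of split sizes
output_split_sizes = torch.tensor([3, 0, 1, 0])

activate_experts = output_split_sizes[:num_experts_per_ep].clone()
for i in range(1, ep_size):
    activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
print(activate_experts)   # tensor([1., 0.]) -> local expert 0 got tokens, expert 1 did not

# dist.all_reduce(activate_experts, group=self.moe_dp_group) then turns this
# 0/1 vector into the number of moe-dp ranks that activated each local expert,
# which is the `activated_experts` value fed to DPGradScalerIn/DPGradScalerOut.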

colossalai/shardformer/policies/mixtral.py (14 changes)

@@ -76,18 +76,6 @@ class MixtralPolicy(Policy):
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
# SubModuleReplacementDescription( # TODO: enable moe tp parallel
# suffix="mlp.gate_proj",
# target_module=Linear1D_Col,
# ),
# SubModuleReplacementDescription(
# suffix="mlp.up_proj",
# target_module=Linear1D_Col,
# ),
# SubModuleReplacementDescription(
# suffix="mlp.down_proj",
# target_module=Linear1D_Row,
# ),
],
)
@@ -98,7 +86,7 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group},
)
],
policy=policy,

colossalai/shardformer/shard/shard_config.py (3 changes)

@@ -46,6 +46,9 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
moe_tp_group: Optional[ProcessGroup] = None
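The three new fields are plain optional process-group handles that default to None, so configurations that do not use MoE are unaffected until the MoE plugin assigns them (as in the first file of this commit). Below is a tiny stand-alone sketch of just that shape, using a hypothetical MoeShardConfigSketch rather than the real ShardConfig.

from dataclasses import dataclass
from typing import Optional

from torch.distributed import ProcessGroup


@dataclass
class MoeShardConfigSketch:
    # mirrors only the MoE-related fields added above
    moe_dp_group: Optional[ProcessGroup] = None
    ep_group: Optional[ProcessGroup] = None
    moe_tp_group: Optional[ProcessGroup] = None


cfg = MoeShardConfigSketch()
print(cfg.ep_group is None)   # True until a plugin assigns the real group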

tests/test_moe/test_moe_zero_fwd_bwd_optim.py → tests/test_moe/test_moe_ep_zero.py (8 changes)

@@ -18,8 +18,7 @@ NUM_BATCH=4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS=2
TOP_K = 2
TOP_K = 1
def split_grad(grad, world_size):
with torch.no_grad():
@@ -96,7 +95,6 @@ def run_zero_with_original_model(stage: int, ep_size: int):
# check grad
name_to_p = {n: p for n, p in ddp_model.named_parameters()}
for n, p in zero_model.named_parameters():
print(f"rank {dist.get_rank()} {n}")
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
@@ -124,9 +122,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(world_size):
def test_moe_ep_zero(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_ep_tp(world_size=4)
test_moe_ep_zero(world_size=4)