mirror of https://github.com/hpcaitech/ColossalAI

[moe] add mixtral dp grad scaling when not all experts are activated

parent e28e05345b
commit 9b9b76bdcd
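Background (not part of the commit): when an expert's gradients are synchronized over the moe data-parallel group, replicas that routed no tokens to that expert contribute zero, so a plain average over all moe_dp_size replicas under-weights the update. The DPGradScalerIn/DPGradScalerOut functions added below rescale the gradient around each expert so the effective average is taken only over the replicas that actually activated it. A toy, non-distributed sketch of the arithmetic (all names and sizes here are illustrative):

import torch

moe_dp_size = 4   # replicas of one expert in the moe data-parallel group
activated = 2     # replicas that routed tokens to it in this step

# per-replica gradients for one expert weight: only `activated` replicas contribute
grads = [torch.ones(2, 2)] * activated + [torch.zeros(2, 2)] * (moe_dp_size - activated)

naive_avg = torch.stack(grads).sum(dim=0) / moe_dp_size   # plain DP averaging over-divides
corrected = naive_avg * (moe_dp_size / activated)         # DPGradScalerOut's backward factor

# the corrected gradient equals the mean over the replicas that contributed
assert torch.allclose(corrected, torch.stack(grads[:activated]).mean(dim=0))
print(naive_avg[0, 0].item(), corrected[0, 0].item())     # 0.5 1.0
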
@@ -141,6 +141,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         # set ep_group after super init
         # TODO do it in a better way
         self.shard_config.ep_group = self.ep_group
+        self.shard_config.moe_dp_group = self.moe_dp_group
+        self.shard_config.moe_tp_group = self.moe_tp_group

         self.force_overlap_comm = force_overlap_comm

@@ -159,7 +161,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

         # create groups from submesh
         for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
-            # axis 0 is dp, axis 1 is tp, axis 2 is sp
+            # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
             submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)

             # hardcode here since we only have 3 axis

@@ -188,7 +190,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                         assert self.moe_tp_group is None
                         self.moe_tp_group = group

-        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])

     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(

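Aside (not part of the diff): the submesh reshape above lays each pipeline stage's ranks out as a (moe_dp, ep, moe_tp) cube, and the loops that follow build one process group per slice along each axis. A standalone NumPy illustration with made-up sizes:

import numpy as np

moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
stage_rank = list(range(moe_dp_size * ep_size * moe_tp_size))  # ranks of one pp stage, e.g. 0..7

submesh = np.array(stage_rank).reshape(moe_dp_size, ep_size, moe_tp_size)

# ranks sharing an (ep, moe_tp) coordinate form a moe_dp group (axis 0), and so on
moe_dp_groups = [submesh[:, j, k].tolist() for j in range(ep_size) for k in range(moe_tp_size)]
ep_groups = [submesh[i, :, k].tolist() for i in range(moe_dp_size) for k in range(moe_tp_size)]
moe_tp_groups = [submesh[i, j, :].tolist() for i in range(moe_dp_size) for j in range(ep_size)]

print(moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(ep_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
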
@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
         return torch.cumsum(inputs, dim=0) - 1


-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
     """
     Scale the gradient back by the number of experts
     because the batch size increases in the moe stage

@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-        if ctx is not None:
-            ctx.ep_size = ep_size
+        ctx.ep_size = ep_size
         return inputs

     @staticmethod

@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
         return grad, None


-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
     """
     Scale the gradient by the number of experts
     because the batch size increases in the moe stage

@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
         return grad, None


+class DPGradScalerIn(torch.autograd.Function):
+    """
+    Scale the gradient by activated_experts / moe_dp_size during backward,
+    undoing the scaling applied by DPGradScalerOut so the gradient flowing
+    back to the rest of the network is unchanged.
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, "shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+        return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+    """
+    Scale the gradient by moe_dp_size / activated_experts during backward,
+    so that averaging the expert's gradient over the whole moe DP group
+    effectively averages only over the replicas that activated it.
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, "shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+        return grad, None, None
+
+
 def _all_to_all(
     inputs: torch.Tensor,
     input_split_sizes: Optional[List[int]] = None,

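Illustration (not from the commit): how the four scalers are meant to bracket an expert's computation, mirroring the MLPExperts and EPMixtralSparseMoeBlock call sites further down. Every forward is an identity; only backward gradients are rescaled, with the DP factors taken straight from the classes above. The run_expert helper below is hypothetical.

# hypothetical helper, for illustration only; the real call sites are in
# MLPExperts.forward and EPMixtralSparseMoeBlock.forward further down
from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut


def run_expert(expert, x, ep_size: int, moe_dp_size: int, activated_experts: int):
    x = EPGradScalerIn.apply(x, ep_size)                          # EP-side rescaling (see EPGradScalerIn above)
    x = DPGradScalerIn.apply(x, moe_dp_size, activated_experts)   # backward: grad *= activated_experts / moe_dp_size
    x = expert(x)                                                 # the expert's own forward
    x = DPGradScalerOut.apply(x, moe_dp_size, activated_experts)  # backward: grad *= moe_dp_size / activated_experts
    x = EPGradScalerOut.apply(x, ep_size)                         # EP-side rescaling (see EPGradScalerOut above)
    return x
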
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer

@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
         Returns:
             torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
         """
-        x = MoeInGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerIn.apply(x, self.ep_size)

         e = x.size(1)
         h = x.size(-1)

@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
         x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
-        x = MoeOutGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerOut.apply(x, self.ep_size)
         return x

@@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging

 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
-        super().__init__(config)
-        self.setup_process_groups(ep_group, tp_group, moe_tp_group)
+    def __init__(self, *args, **kwargs):
+        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+        assert tp_group is not None
+        assert moe_dp_group is not None
+        assert ep_group is not None
+        assert moe_tp_group is not None
+
         # setup ep group
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)

@@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):

         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
-            p.ep_group = ep_group
+            set_moe_tensor_ep_group(p, ep_group)
+
+        # setup moe_dp group
+        self.moe_dp_group = moe_dp_group
+        self.moe_dp_size = moe_dp_group.size()

         # setup global tp group
         self.tp_group = tp_group

@@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):

     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
+        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
     ) -> "EPMixtralSparseMoeBlock":
+        # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(ep_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

+        with torch.no_grad():
+            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+            for i in range(1, self.ep_size):
+                activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+            activate_experts = (activate_experts > 0).float()
+        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()

-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)

         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
             if self.num_experts_per_ep == 1:
                 # no need to split
                 expert = self.experts[self.expert_start_idx]
+                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
                 output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                 output_states = expert.w2(output_states)
+                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
             else:
                 output_states_splits = output_states.split(output_split_sizes.tolist())
                 output_states_list = []
                 for i, split_states in enumerate(output_states_splits):
                     if split_states.size(0) == 0:
                         continue
+                    split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                     split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                     split_states = expert.w2(split_states)
+                    split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+
+        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)

-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)

         recover_experts_idx = torch.empty_like(selected_experts_idx)

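Aside (not part of the diff): the added block above derives, for each local expert, how many moe-DP replicas activated it. output_split_sizes counts the tokens this rank will receive per (source ep rank, local expert); summing over source ranks and thresholding at zero gives a 0/1 mask, and the all_reduce over moe_dp_group turns that mask into the replica count fed to DPGradScalerIn/Out. A non-distributed sketch with made-up numbers:

import torch

ep_size, num_experts_per_ep = 2, 2
# token counts this rank will receive, laid out as [ep_rank0 experts..., ep_rank1 experts...]
output_split_sizes = torch.tensor([3, 0, 0, 5])

activate_experts = output_split_sizes[:num_experts_per_ep].clone()
for i in range(1, ep_size):
    activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
print(activate_experts)  # tensor([1., 1.]): both local experts received tokens from some ep rank

# in the real code this is followed by
#   dist.all_reduce(activate_experts, group=self.moe_dp_group)
# which sums the 0/1 masks across moe-DP replicas, yielding the per-expert
# count of replicas that activated it
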
@@ -76,18 +76,6 @@ class MixtralPolicy(Policy):
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
                 ),
-                # SubModuleReplacementDescription( # TODO: enable moe tp parallel
-                #     suffix="mlp.gate_proj",
-                #     target_module=Linear1D_Col,
-                # ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.up_proj",
-                #     target_module=Linear1D_Col,
-                # ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.down_proj",
-                #     target_module=Linear1D_Row,
-                # ),
             ],
         )

@@ -98,7 +86,7 @@ class MixtralPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="block_sparse_moe",
                     target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
+                    kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group},
                 )
             ],
             policy=policy,

@@ -46,6 +46,9 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)

+    # for moe related
+    moe_dp_group: Optional[ProcessGroup] = None
     ep_group: Optional[ProcessGroup] = None
+    moe_tp_group: Optional[ProcessGroup] = None

@@ -18,8 +18,7 @@ NUM_BATCH=4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS=2
-TOP_K = 2
-
+TOP_K = 1

 def split_grad(grad, world_size):
     with torch.no_grad():

@@ -96,7 +95,6 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     # check grad
     name_to_p = {n: p for n, p in ddp_model.named_parameters()}
     for n, p in zero_model.named_parameters():
-        print(f"rank {dist.get_rank()} {n}")
         zero_grad = zero_optimizer.get_param_grad(p)
         if name_to_p[n].grad is None:
             name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)

@@ -124,9 +122,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(world_size):
+def test_moe_ep_zero(world_size):
     spawn(run_dist, world_size)


 if __name__ == "__main__":
-    test_moe_ep_tp(world_size=4)
+    test_moe_ep_zero(world_size=4)