[hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args

* test: add ep tests

* fix: remove get_process_group_ranks

* fix: fix src_rank
pull/5060/head
Wenhao Chen 2023-11-17 10:53:00 +08:00 committed by GitHub
parent 97cd0cd559
commit 3c08f17348
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 28 deletions

View File

@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
def forward( def forward(
ctx: Any, ctx: Any,
inputs: Tensor, inputs: Tensor,
groups: Tuple[ProcessGroup], groups: Tuple[ProcessGroup, ProcessGroup],
src_rank: int
) -> Tensor: ) -> Tensor:
""" """
Returns: Returns:
@ -159,12 +160,12 @@ class HierarchicalAllToAll(torch.autograd.Function):
# TODO: we can reduce comm volume by removing empty capacity # TODO: we can reduce comm volume by removing empty capacity
if ctx is not None: if ctx is not None:
ctx.comm_grps = groups ctx.comm_grps = groups
ctx.src_rank = src_rank
intra_node_group, inter_node_group = groups intra_node_group, inter_node_group = groups
local_world_size = dist.get_world_size(intra_node_group) local_world_size = dist.get_world_size(intra_node_group)
num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1 num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
world_size = local_world_size * num_group world_size = local_world_size * num_group
src_rank = dist.get_process_group_ranks(intra_node_group)[0]
outputs = torch.empty_like(inputs) outputs = torch.empty_like(inputs)
if dist.get_rank() == src_rank: if dist.get_rank() == src_rank:
@ -196,9 +197,10 @@ class HierarchicalAllToAll(torch.autograd.Function):
return outputs return outputs
@staticmethod @staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return ( return (
HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps), HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
None,
None, None,
) )

View File

@ -13,7 +13,7 @@ from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
class SparseMLP(nn.Module): class SparseMLP(nn.Module):
@ -105,8 +105,11 @@ class SparseMLP(nn.Module):
if self.expert_parallel is not None: if self.expert_parallel is not None:
self.ep_group = get_ep_group(self.experts) self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts) self.ep_size = get_ep_size(self.experts)
self.ep_hierarchical_group = create_ep_hierarchical_group( self.ep_hierarchical_group = None
self.ep_group) if enable_hierarchical_comm else None if enable_hierarchical_comm:
self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
get_ep_group_ranks(self.experts)
)
self.dp_group = get_dp_group(self.experts) self.dp_group = get_dp_group(self.experts)
else: else:
self.ep_group = None self.ep_group = None
@ -225,10 +228,10 @@ class SparseMLP(nn.Module):
""" """
if not overlap or dist.get_world_size(self.ep_group) == 1: if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None: if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group) expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input) expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group) expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
return expert_output return expert_output
else: else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]

View File

@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):
def create_ep_hierarchical_group( def create_ep_hierarchical_group(
ep_group: dist.ProcessGroup, ep_group_ranks: List[int],
nproc_per_node: Optional[int] = None, nproc_per_node: Optional[int] = None,
) -> Tuple[Optional[dist.ProcessGroup], ) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
Optional[dist.ProcessGroup]]:
""" """
e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4 e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
""" """
assert dist.is_initialized(), "Please initialize torch.distributed first." assert dist.is_initialized(), "Please initialize torch.distributed first."
rank = dist.get_rank()
if nproc_per_node is None: if nproc_per_node is None:
nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE") nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@ -197,24 +197,23 @@ def create_ep_hierarchical_group(
"nproc_per_node should be a divisor of world_size." "nproc_per_node should be a divisor of world_size."
num_node = dist.get_world_size() // nproc_per_node num_node = dist.get_world_size() // nproc_per_node
rank = dist.get_rank() intra_src_rank = None
ep_ranks = dist.get_process_group_ranks(ep_group)
ep_intra_node_group = None ep_intra_node_group = None
for i in range(num_node): for i in range(num_node):
ep_intra_ranks = [ ep_intra_ranks = [
i * nproc_per_node + j i * nproc_per_node + j
for j in range(nproc_per_node) for j in range(nproc_per_node)
if j in ep_ranks if j in ep_group_ranks
] ]
group = dist.new_group(ep_intra_ranks) group = dist.new_group(ep_intra_ranks)
if rank in ep_intra_ranks: if rank in ep_intra_ranks:
assert ep_intra_node_group is None assert ep_intra_node_group is None
ep_intra_node_group = group ep_intra_node_group = group
intra_src_rank = ep_intra_ranks[0]
ep_inter_node_group = None ep_inter_node_group = None
ep_inter_ranks = [ ep_inter_ranks = [
ep_ranks[0] + i * nproc_per_node ep_group_ranks[0] + i * nproc_per_node
for i in range(num_node) for i in range(num_node)
] ]
if len(ep_inter_ranks) > 1: if len(ep_inter_ranks) > 1:
@ -222,4 +221,4 @@ def create_ep_hierarchical_group(
if rank in ep_inter_ranks: if rank in ep_inter_ranks:
ep_inter_node_group = group ep_inter_node_group = group
return ep_intra_node_group, ep_inter_node_group return intra_src_rank, ep_intra_node_group, ep_inter_node_group

View File

@ -1,3 +1,5 @@
from typing import List
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
return dist.get_rank(get_dp_group(tensor)) return dist.get_rank(get_dp_group(tensor))
def get_ep_group_ranks(tensor: torch.Tensor) -> int: def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
""" """
Get the expert parallel group ranks of the given tensor. Get the expert parallel group ranks of the given tensor.
@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
return tensor.moe_info.ep_group_ranks return tensor.moe_info.ep_group_ranks
def get_dp_group_ranks(tensor: torch.Tensor) -> int: def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
""" """
Get the data parallel group ranks of the given tensor. Get the data parallel group ranks of the given tensor.

View File

@ -1,5 +1,6 @@
import os import os
import warnings import warnings
from typing import Dict
import pytest import pytest
import torch import torch
@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
local_param.data.copy_(all_param.data) local_param.data.copy_(all_param.data)
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
assert batch_size % world_size == 0 assert batch_size % world_size == 0
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__() MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP") MOE_MANAGER.setup(parallel="EP")
os.environ["LOCAL_WORLD_SIZE"] = str(world_size) enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
enable_hierarchical_comm = torch.__version__ >= "1.13.1" if enable_hierarchical_comm:
os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
ep_model = SparseMLP( ep_model = SparseMLP(
num_experts=num_experts, num_experts=num_experts,
hidden_size=dim, hidden_size=dim,
@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
tp_grad_handler = MoeGradientHandler(tp_model) tp_grad_handler = MoeGradientHandler(tp_model)
rank = dist.get_rank() rank = dist.get_rank()
torch.cuda.manual_seed(seed)
input_data = torch.randn(batch_size, dim, device=get_current_device()) input_data = torch.randn(batch_size, dim, device=get_current_device())
micro_batch_size = batch_size // world_size micro_batch_size = batch_size // world_size
index = rank * micro_batch_size index = rank * micro_batch_size
@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
@pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dim", [64]) @pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("seed", [42, 127]) @pytest.mark.parametrize("config", [
{"enable_hierarchical_comm": False},
{"enable_hierarchical_comm": True},
])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
if __name__ == '__main__': if __name__ == '__main__':
test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42) test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)