mirror of https://github.com/hpcaitech/ColossalAI

[hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args
* test: add ep tests
* fix: remove get_process_group_ranks
* fix: fix src_rank

pull/5060/head

parent 97cd0cd559
commit 3c08f17348
@@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        groups: Tuple[ProcessGroup],
+        groups: Tuple[ProcessGroup, ProcessGroup],
+        src_rank: int
     ) -> Tensor:
         """
         Returns:
@@ -159,12 +160,12 @@ class HierarchicalAllToAll(torch.autograd.Function):
         # TODO: we can reduce comm volume by removing empty capacity
         if ctx is not None:
             ctx.comm_grps = groups
+            ctx.src_rank = src_rank
         intra_node_group, inter_node_group = groups

         local_world_size = dist.get_world_size(intra_node_group)
         num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
         world_size = local_world_size * num_group
-        src_rank = dist.get_process_group_ranks(intra_node_group)[0]
         outputs = torch.empty_like(inputs)

         if dist.get_rank() == src_rank:
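
For orientation, the forward pass performs a two-level exchange: each rank's shard is first gathered onto the intra-node leader (src_rank), the leaders run an all-to-all across nodes, and the result is scattered back within each node, which is why src_rank is now stashed on ctx for reuse in backward. The sketch below only illustrates that communication pattern under those assumptions; the helper name is made up, and the real forward additionally handles the capacity layout and element ordering.

    import torch
    import torch.distributed as dist

    def hierarchical_a2a_sketch(inputs, intra_node_group, inter_node_group, src_rank):
        # Step 1: gather every intra-node shard onto the node leader (src_rank).
        local_world_size = dist.get_world_size(intra_node_group)
        is_leader = dist.get_rank() == src_rank
        gather_list = [torch.empty_like(inputs) for _ in range(local_world_size)] if is_leader else None
        dist.gather(inputs, gather_list, dst=src_rank, group=intra_node_group)

        outputs = torch.empty_like(inputs)
        scatter_list = None
        if is_leader:
            staged = torch.cat(gather_list, dim=0)
            exchanged = torch.empty_like(staged)
            if inter_node_group is not None:
                # Step 2: leaders exchange node-level buffers across nodes.
                dist.all_to_all_single(exchanged, staged, group=inter_node_group)
            else:
                exchanged = staged
            scatter_list = [c.contiguous() for c in exchanged.chunk(local_world_size, dim=0)]
        # Step 3: scatter the exchanged buffer back to the ranks of this node.
        dist.scatter(outputs, scatter_list, src=src_rank, group=intra_node_group)
        return outputs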
@@ -196,9 +197,10 @@ class HierarchicalAllToAll(torch.autograd.Function):
         return outputs

     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
             None,
+            None,
         )

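
The extra None in backward follows from autograd's contract: backward must return one gradient slot per forward argument after ctx, and non-differentiable arguments such as groups and src_rank receive None. A self-contained toy Function showing the same pattern (names here are illustrative, not from the repo):

    from typing import Any, Tuple

    import torch
    from torch import Tensor

    class ScaleBy(torch.autograd.Function):
        @staticmethod
        def forward(ctx: Any, inputs: Tensor, factor: float) -> Tensor:
            # Non-tensor arguments are stashed on ctx, just like src_rank above.
            ctx.factor = factor
            return inputs * factor

        @staticmethod
        def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
            # One entry per forward argument: a gradient for `inputs`, None for `factor`.
            return grad_outputs[0] * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    ScaleBy.apply(x, 2.0).sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 2.0))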

@@ -13,7 +13,7 @@ from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
 from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size


 class SparseMLP(nn.Module):
@@ -105,8 +105,11 @@ class SparseMLP(nn.Module):
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
-            self.ep_hierarchical_group = create_ep_hierarchical_group(
-                self.ep_group) if enable_hierarchical_comm else None
+            self.ep_hierarchical_group = None
+            if enable_hierarchical_comm:
+                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+                    get_ep_group_ranks(self.experts)
+                )
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
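
With the new signature, create_ep_hierarchical_group returns (intra_src_rank, intra_node_group, inter_node_group); the starred assignment above therefore leaves self.ep_hierarchical_group as a two-element list that HierarchicalAllToAll.forward later unpacks into (intra_node_group, inter_node_group). A minimal sketch of that unpacking, with placeholder strings standing in for the process groups:

    # Placeholder values; the real ones come from create_ep_hierarchical_group(get_ep_group_ranks(...)).
    result = (0, "intra_node_pg", "inter_node_pg")  # (intra_src_rank, intra_group, inter_group)

    ep_intra_src_rank, *ep_hierarchical_group = result
    assert ep_intra_src_rank == 0
    assert ep_hierarchical_group == ["intra_node_pg", "inter_node_pg"]  # a list of the two groups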
@@ -225,10 +228,10 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
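
Between the two hierarchical exchanges, the reshape only reinterprets the received buffer as one slice per (sending rank, local expert) pair. A shape-only sketch with made-up sizes:

    import torch

    ep_size, num_local_experts, capacity, hidden_size = 4, 2, 8, 16  # hypothetical sizes
    dispatch_data = torch.randn(ep_size * num_local_experts, capacity, hidden_size)

    # The all-to-all preserves the element count per rank, so the exchanged buffer
    # can be viewed as (ep_size, num_local_experts, capacity, hidden_size).
    expert_input = dispatch_data.reshape(ep_size, num_local_experts, -1, hidden_size)
    assert expert_input.shape == (ep_size, num_local_experts, capacity, hidden_size)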

@@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):


 def create_ep_hierarchical_group(
-    ep_group: dist.ProcessGroup,
+    ep_group_ranks: List[int],
     nproc_per_node: Optional[int] = None,
-) -> Tuple[Optional[dist.ProcessGroup],
-           Optional[dist.ProcessGroup]]:
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
     """
     e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
     Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
     """
     assert dist.is_initialized(), "Please initialize torch.distributed first."
+    rank = dist.get_rank()
     if nproc_per_node is None:
         nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@@ -197,24 +197,23 @@ def create_ep_hierarchical_group(
             "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

-    rank = dist.get_rank()
-    ep_ranks = dist.get_process_group_ranks(ep_group)
-
+    intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
         ep_intra_ranks = [
             i * nproc_per_node + j
             for j in range(nproc_per_node)
-            if j in ep_ranks
+            if j in ep_group_ranks
         ]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
             ep_intra_node_group = group
+            intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
     ep_inter_ranks = [
-        ep_ranks[0] + i * nproc_per_node
+        ep_group_ranks[0] + i * nproc_per_node
         for i in range(num_node)
     ]
     if len(ep_inter_ranks) > 1:
@@ -222,4 +221,4 @@ def create_ep_hierarchical_group(
         if rank in ep_inter_ranks:
             ep_inter_node_group = group

-    return ep_intra_node_group, ep_inter_node_group
+    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
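
The docstring example can be checked with a standalone sketch of the rank arithmetic; the helper below (an illustrative name, not part of the repo) reproduces the grouping lists without calling dist.new_group. In the real function, intra_src_rank is the first entry of the caller's intra-node list.

    from typing import List, Tuple

    def hierarchical_ranks(ep_group_ranks: List[int], nproc_per_node: int, num_node: int) -> Tuple[List[List[int]], List[int]]:
        # Intra-node groups: the EP ranks that share a node.
        intra = [
            [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
            for i in range(num_node)
        ]
        # Inter-node group: one representative (the node-local leader) per node.
        inter = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
        return intra, inter

    intra, inter = hierarchical_ranks([1, 2, 5, 6], nproc_per_node=4, num_node=2)
    assert intra == [[1, 2], [5, 6]]
    assert inter == [1, 5]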

@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))


-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.

@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_group_ranks


-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.


@@ -1,5 +1,6 @@
 import os
 import warnings
+from typing import Dict

 import pytest
 import torch
@@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
             local_param.data.copy_(all_param.data)


-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
     assert batch_size % world_size == 0

     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="EP")
-    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
-    enable_hierarchical_comm = torch.__version__ >= "1.13.1"
+    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+    if enable_hierarchical_comm:
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
     ep_model = SparseMLP(
         num_experts=num_experts,
         hidden_size=dim,
@@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)

     rank = dist.get_rank()
-    torch.cuda.manual_seed(seed)
     input_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
@@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("seed", [42, 127])
+@pytest.mark.parametrize("config", [
+    {"enable_hierarchical_comm": False},
+    {"enable_hierarchical_comm": True},
+])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
-    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)


 if __name__ == '__main__':
-    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42)
+    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)