mirror of https://github.com/hpcaitech/ColossalAI

[hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args
* test: add ep tests
* fix: remove get_process_group_ranks
* fix: fix src_rank

parent 97cd0cd559
commit 3c08f17348
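At a glance, this hotfix changes create_ep_hierarchical_group to take the expert-parallel rank list instead of a ProcessGroup and to return the intra-node source rank in addition to the two groups; every caller of HierarchicalAllToAll then has to pass that source rank explicitly. A minimal before/after sketch of the calling convention, paraphrased from the hunks below (comments only, not executable against a live process group):

# Before this commit: the helper received a ProcessGroup, looked up its ranks
# internally via dist.get_process_group_ranks, and returned only the two groups.
#     intra_group, inter_group = create_ep_hierarchical_group(ep_group)
#     out = HierarchicalAllToAll.apply(x, (intra_group, inter_group))
#
# After this commit: the caller supplies the EP rank list, and the helper also
# returns the intra-node source rank, which must be threaded through every
# HierarchicalAllToAll call.
#     src_rank, intra_group, inter_group = create_ep_hierarchical_group(ep_group_ranks)
#     out = HierarchicalAllToAll.apply(x, (intra_group, inter_group), src_rank)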
@@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        groups: Tuple[ProcessGroup],
+        groups: Tuple[ProcessGroup, ProcessGroup],
+        src_rank: int
     ) -> Tensor:
         """
         Returns:
@@ -159,12 +160,12 @@ class HierarchicalAllToAll(torch.autograd.Function):
         # TODO: we can reduce comm volume by removing empty capacity
         if ctx is not None:
             ctx.comm_grps = groups
+            ctx.src_rank = src_rank
         intra_node_group, inter_node_group = groups

         local_world_size = dist.get_world_size(intra_node_group)
         num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
         world_size = local_world_size * num_group
-        src_rank = dist.get_process_group_ranks(intra_node_group)[0]
         outputs = torch.empty_like(inputs)

         if dist.get_rank() == src_rank:
@@ -196,9 +197,10 @@ class HierarchicalAllToAll(torch.autograd.Function):
         return outputs

     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
+            None,
             None,
         )
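The extra None in backward follows the torch.autograd.Function contract: backward must return one entry per positional argument of forward, and non-tensor arguments such as groups and the new src_rank receive None. A standalone sketch of that rule, using a toy Function rather than the ColossalAI kernel:

import torch
from torch.autograd import Function


class ScaleBy(Function):
    """Toy autograd Function with one tensor input and one non-tensor input."""

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return inputs * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: a gradient for inputs,
        # and None for the non-differentiable factor.
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])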
@@ -13,7 +13,7 @@ from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
 from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size


 class SparseMLP(nn.Module):
@@ -105,8 +105,11 @@ class SparseMLP(nn.Module):
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
-            self.ep_hierarchical_group = create_ep_hierarchical_group(
-                self.ep_group) if enable_hierarchical_comm else None
+            self.ep_hierarchical_group = None
+            if enable_hierarchical_comm:
+                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+                    get_ep_group_ranks(self.experts)
+                )
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
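The new assignment in SparseMLP relies on starred unpacking: the first element of the 3-tuple returned by create_ep_hierarchical_group becomes ep_intra_src_rank, while the two process groups land in a list that is later handed to HierarchicalAllToAll.apply as its groups argument (see the next hunk). A plain-Python illustration with hypothetical placeholder values instead of real process groups:

# Placeholder stand-ins for the (src_rank, intra_group, inter_group) tuple
# that create_ep_hierarchical_group now returns.
result = (1, "intra_node_group", "inter_node_group")

ep_intra_src_rank, *ep_hierarchical_group = result
print(ep_intra_src_rank)      # 1
print(ep_hierarchical_group)  # ['intra_node_group', 'inter_node_group']

# The two groups can be unpacked again, as HierarchicalAllToAll.forward does:
intra_node_group, inter_node_group = ep_hierarchical_group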
@@ -225,10 +228,10 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):


 def create_ep_hierarchical_group(
-    ep_group: dist.ProcessGroup,
+    ep_group_ranks: List[int],
     nproc_per_node: Optional[int] = None,
-) -> Tuple[Optional[dist.ProcessGroup],
-           Optional[dist.ProcessGroup]]:
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
     """
     e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
     Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
     """
     assert dist.is_initialized(), "Please initialize torch.distributed first."
+    rank = dist.get_rank()
     if nproc_per_node is None:
         nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@@ -197,24 +197,23 @@ def create_ep_hierarchical_group(
             "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

-    rank = dist.get_rank()
-    ep_ranks = dist.get_process_group_ranks(ep_group)
-
+    intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
         ep_intra_ranks = [
             i * nproc_per_node + j
             for j in range(nproc_per_node)
-            if j in ep_ranks
+            if j in ep_group_ranks
         ]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
             ep_intra_node_group = group
+            intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
     ep_inter_ranks = [
-        ep_ranks[0] + i * nproc_per_node
+        ep_group_ranks[0] + i * nproc_per_node
         for i in range(num_node)
     ]
     if len(ep_inter_ranks) > 1:
@@ -222,4 +221,4 @@ def create_ep_hierarchical_group(
         if rank in ep_inter_ranks:
             ep_inter_node_group = group

-    return ep_intra_node_group, ep_inter_node_group
+    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
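The grouping arithmetic can be checked without torch.distributed by replaying the index math on plain lists. The sketch below reproduces the docstring example with ep_group_ranks = [1, 2, 5, 6] and nproc_per_node = 4, assuming a world size of 8 (the world size is not stated in the docstring):

# Illustrative re-computation of the intra-/inter-node rank lists built by
# create_ep_hierarchical_group; the dist.new_group calls are omitted.
ep_group_ranks = [1, 2, 5, 6]
nproc_per_node = 4
world_size = 8  # assumed for this example
num_node = world_size // nproc_per_node

ep_intra_groups = [
    [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    for i in range(num_node)
]
ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]

print(ep_intra_groups)  # [[1, 2], [5, 6]]
print(ep_inter_ranks)   # [1, 5]
# The intra-node source rank of rank 5 or 6 is ep_intra_groups[1][0], i.e. 5,
# which is the value returned as intra_src_rank on those ranks.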
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))


-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.

@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_group_ranks


-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.

@@ -1,5 +1,6 @@
 import os
 import warnings
+from typing import Dict

 import pytest
 import torch
@@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
             local_param.data.copy_(all_param.data)


-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
     assert batch_size % world_size == 0

     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="EP")
-    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
-    enable_hierarchical_comm = torch.__version__ >= "1.13.1"
+    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+    if enable_hierarchical_comm:
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
     ep_model = SparseMLP(
         num_experts=num_experts,
         hidden_size=dim,
@@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)

     rank = dist.get_rank()
-    torch.cuda.manual_seed(seed)
     input_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
@@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("seed", [42, 127])
+@pytest.mark.parametrize("config", [
+    {"enable_hierarchical_comm": False},
+    {"enable_hierarchical_comm": True},
+])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
-    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)


 if __name__ == '__main__':
-    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42)
+    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
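With the seed parameter replaced by a parametrized config dict, each dict yields its own test case and is forwarded through spawn to run_test as a keyword argument, where config.get("enable_hierarchical_comm", False) decides whether the hierarchical path is exercised. A minimal sketch of that flow outside the ColossalAI harness; launch_worker is a hypothetical stand-in for spawn plus run_test:

from typing import Dict

import pytest


def launch_worker(world_size: int, **kwargs) -> bool:
    """Hypothetical stand-in for spawn(run_test, ...); returns the resolved flag."""
    config: Dict = kwargs["config"]
    return config.get("enable_hierarchical_comm", False)


@pytest.mark.parametrize("config", [
    {"enable_hierarchical_comm": False},
    {"enable_hierarchical_comm": True},
])
def test_flag_is_forwarded(config: Dict):
    assert launch_worker(world_size=2, config=config) == config["enable_hierarchical_comm"]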