mirror of https://github.com/hpcaitech/ColossalAI
commit f66469e209 ("update")
parent 8aef2dba02
@@ -4,6 +4,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 from colossalai.lazy import LazyInitContext
 from colossalai.moe import SparseMLP
+from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor
 
 
 class MixtralSparseMLP:
@@ -33,57 +34,81 @@ class MixtralSparseMLP:
         Raises:
             AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        LazyInitContext.materialize(module)
-        # get the attributes of the module
-        moe_kwargs = dict(
-            num_experts=module.num_experts,
-            hidden_size=module.hidden_dim,
-            intermediate_size=module.ffn_dim,
-            router_top_k=module.top_k,
-            # router_capacity_factor_train = module.
-            # router_capacity_factor_eval = module.
-            # router_min_capacity = module.
-            # router_noisy_policy = module.
-            # router_drop_tks = module.
-            mlp_activation="silu",
-            mlp_gated=True,
-            # enable_load_balance = module.
-            # load_balance_tolerance = module.
-            # load_balance_beam_width = module.
-            # load_balance_group_swap_factor = module.
-            # enable_kernel = module.
-            # enable_comm_overlap = module.
-            # enable_hierarchical_comm = module.
-            return_gate_logits=True,
-        )
-        dtype = module.gate.weight.dtype
-        device = module.gate.weight.device
-        sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
-
-        w1 = None
-        w2 = None
-        w3 = None
-        for i in module.experts:
-            wi_1 = i.w1.weight.data.transpose(0, 1).unsqueeze(0)
-            wi_2 = i.w2.weight.data.transpose(0, 1).unsqueeze(0)
-            wi_3 = i.w3.weight.data.transpose(0, 1).unsqueeze(0)
-            if w1 is None:
-                w1 = wi_1
-            else:
-                w1 = torch.cat([w1, wi_1], dim=0)
-            if w2 is None:
-                w2 = wi_2
-            else:
-                w2 = torch.cat([w2, wi_2], dim=0)
-            if w3 is None:
-                w3 = wi_3
-            else:
-                w3 = torch.cat([w3, wi_3], dim=0)
-
-        sparse_mlp.experts.wi_gate.data = w1[:2]
-        sparse_mlp.experts.wi_up.data = w3[:2]
-        sparse_mlp.experts.wo.data = w2[:2]
-        sparse_mlp.gate_weight = module.gate.weight
-
-        return sparse_mlp.to(dtype).to(device)
+        with torch.no_grad():
+            LazyInitContext.materialize(module)
+
+            # get the attributes of the module
+            moe_kwargs = dict(
+                num_experts=module.num_experts,
+                hidden_size=module.hidden_dim,
+                intermediate_size=module.ffn_dim,
+                router_top_k=module.top_k,
+                router_norm=True,
+                router_loss=False,
+                # router_capacity_factor_train = .
+                # router_capacity_factor_eval = .
+                mlp_activation="silu",
+                mlp_gated=True,
+                # enable_load_balance = .
+                # load_balance_tolerance = .
+                # load_balance_beam_width = .
+                # load_balance_group_swap_factor = .
+                # enable_kernel = .
+                # enable_comm_overlap = .
+                # enable_hierarchical_comm = .
+                return_gate_logits=True,
+            )
+            dtype = module.gate.weight.dtype
+            device = module.gate.weight.device
+            sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
+
+            # cat all experts
+            w1 = None
+            w2 = None
+            w3 = None
+            for i in module.experts:
+                # origin
+                wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0)
+                wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0)
+                wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0)
+                # cat
+                w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0)
+                w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0)
+                w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0)
+
+            # get local experts
+            if is_moe_tensor(sparse_mlp.experts.wi_gate):
+                ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate)
+                expert_num = sparse_mlp.experts.wi_gate.shape[0]
+                expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num)
+            else:
+                expert_slice = slice(None)
+            w1 = w1[expert_slice].clone().detach()
+            w2 = w2[expert_slice].clone().detach()
+            w3 = w3[expert_slice].clone().detach()
+            assert (
+                w1.shape == sparse_mlp.experts.wi_gate.shape
+            ), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}"
+            assert (
+                w2.shape == sparse_mlp.experts.wo.shape
+            ), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}"
+            assert (
+                w3.shape == sparse_mlp.experts.wi_up.shape
+            ), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}"
+
+            # assign new param to colossal moe moudle
+            sparse_mlp.experts.wi_gate.data = w1
+            sparse_mlp.experts.wi_up.data = w3
+            sparse_mlp.experts.wo.data = w2
+            sparse_mlp.gate_weight = module.gate.weight
+
+            # TODO: fix
+            # the old weight is referenced somewhere so we can not del it.
+            # Change data pointer of old weight to release memory.
+            # The pointer will not be used and can be any pointer.
+            for i in module.experts:
+                i.w1.weight.data = w1
+                i.w2.weight.data = w2
+                i.w3.weight.data = w3
+
+        return sparse_mlp
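
The hunk above stacks each Mixtral expert's nn.Linear weights into the fused per-expert tensors that SparseMLP uses. A minimal sketch of that reshaping with toy sizes (everything here is illustrative and not part of the diff):

import torch
import torch.nn as nn

# Toy dimensions, assumptions for illustration only.
num_experts, hidden_dim, ffn_dim = 4, 8, 16

# Each HF Mixtral expert stores w1/w3 as nn.Linear(hidden_dim, ffn_dim); nn.Linear
# keeps its weight as (out_features, in_features).
experts = [nn.Linear(hidden_dim, ffn_dim) for _ in range(num_experts)]

# transpose(0, 1) turns (out, in) into (in, out), unsqueeze(0) adds the expert axis,
# and torch.cat stacks the slices into one (num_experts, in, out) tensor -- the layout
# expected by SparseMLP's experts.wi_gate / wi_up / wo.
w1 = torch.cat([e.weight.data.clone().transpose(0, 1).unsqueeze(0) for e in experts], dim=0)
print(w1.shape)  # torch.Size([4, 8, 16]) -> (num_experts, hidden_dim, ffn_dim)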
@@ -0,0 +1,43 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+    with open(path, "r") as fd:
+        return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+    with open("README.md", encoding="utf-8") as f:
+        return f.read()
+
+
+def fetch_version():
+    with open("version.txt", "r") as f:
+        return f.read().strip()
+
+
+setup(
+    name="colossal_moe",
+    version=fetch_version(),
+    packages=find_packages(
+        exclude=(
+            "tests",
+            "benchmarks",
+            "*.egg-info",
+        )
+    ),
+    description="Colossal-AI MoE",
+    long_description=fetch_readme(),
+    long_description_content_type="text/markdown",
+    license="Apache Software License 2.0",
+    url="https://github.com/hpcaitech",
+    install_requires=fetch_requirements("requirements.txt"),
+    python_requires=">=3.6",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: Apache Software License",
+        "Environment :: GPU :: NVIDIA CUDA",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: System :: Distributed Computing",
+    ],
+)
@@ -0,0 +1,31 @@
+import copy
+
+import torch
+from colossal_moe.models.mixtral_layer import MixtralSparseMLP
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+
+class Config:
+    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_local_experts = num_local_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.hidden_act = hidden_act
+
+
+def test_moe_layer():
+    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
+    mistral_moe = MixtralSparseMoeBlock(config).cuda()
+    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()
+
+    data = torch.randn(2, 8, 4).cuda()
+    mistral_output = mistral_moe(data)[0]
+    colossal_output = colossal_moe(data)[0]
+    assert torch.allclose(
+        mistral_output, colossal_output
+    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
+
+
+if __name__ == "__main__":
+    test_moe_layer()
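
The new test checks numerical equivalence for a single MoE block. For a whole model the same from_native_module conversion could in principle be applied per decoder layer; the sketch below is illustrative only (in this application the swap is expected to go through the shardformer policy, see MixtralForCausalLMPolicy imported in train.py), and model.model.layers / block_sparse_moe are the attribute names used by transformers' Mixtral implementation:

from colossal_moe.models.mixtral_layer import MixtralSparseMLP


def replace_moe_blocks(model):
    # Replace every HF MixtralSparseMoeBlock with the converted SparseMLP in place.
    for layer in model.model.layers:
        layer.block_sparse_moe = MixtralSparseMLP.from_native_module(layer.block_sparse_moe)
    return model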
@@ -1,21 +1,20 @@
 import argparse
-from typing import Dict
 
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer, T5Tokenizer
+from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.lazy import LazyInitContext
+from colossalai.moe import MOE_MANAGER, apply_load_balance
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 
@@ -23,21 +22,6 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
 
-def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
-    texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
-    data = tokenizer(
-        texts,
-        return_tensors="pt",
-        padding="max_length",
-        truncation=True,
-        max_length=max_length,
-        add_special_tokens=False,
-    )
-    data = {k: v.cuda() for k, v in data.items()}
-    data["labels"] = data["input_ids"].clone()
-    return data
-
-
 class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
         self.num_samples = num_samples
@@ -188,7 +172,6 @@ def main():
     # Launch ColossalAI
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()
-    test_mode = args.model_name == "test"
 
     # Set plugin
     booster_kwargs = {}
@@ -247,15 +230,20 @@ def main():
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
 
     # Build OpenMoe model
-    config = MixtralConfig(
-        hidden_size=32,
-        intermediate_size=64,
-        num_hidden_layers=4,
-        num_attention_heads=4,
-        num_key_value_heads=4,
-        use_cache=False,
-    )
-    model = MixtralForCausalLM(config).bfloat16()
+    # config = MixtralConfig(
+    #     hidden_size=2048,
+    #     intermediate_size=4096,
+    #     num_hidden_layers=4,
+    #     num_attention_heads=4,
+    #     num_key_value_heads=4,
+    #     use_cache=False,
+    # )
+    config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    config.use_cache = False
+    # config.num_local_experts = 1
+    init_ctx = LazyInitContext(default_device=get_current_device())
+    with init_ctx:
+        model = MixtralForCausalLM(config).bfloat16()
     coordinator.print_on_master(f"Finish init model with config:\n{config}")
 
     # Enable gradient checkpointing
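
For context, a condensed sketch of the initialization flow this hunk moves to; plugin, booster_kwargs, args, and the dataloader come from the surrounding train.py and are assumed here rather than shown:

config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
config.use_cache = False

# Parameters are created lazily, so the full Mixtral model is not allocated eagerly.
with LazyInitContext(default_device=get_current_device()):
    model = MixtralForCausalLM(config).bfloat16()

optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
booster = Booster(plugin=plugin, **booster_kwargs)
# Assumption based on ColossalAI's lazy-init workflow: boost() is where the lazily
# initialized parameters are materialized and sharded according to the plugin.
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)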
@@ -270,7 +258,7 @@ def main():
     )
 
     # Set optimizer
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
 
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
@@ -292,7 +280,6 @@ def main():
     ) as pbar:
         for step in pbar:
             if use_pipeline:
-                exit()
                 # Forward pass
                 outputs = booster.execute_pipeline(
                     train_dataloader_iter,
@@ -307,11 +294,9 @@ def main():
                 loss = outputs["loss"]
                 pbar.set_postfix({"loss": loss.item()})
             else:
-                print("1111111\n\n")
                 # Forward pass
                 data = next(train_dataloader_iter)
                 data = move_to_cuda(data, torch.cuda.current_device())
-                print(data)
                 outputs = model(**data)
                 loss = outputs["loss"]
                 # Backward
@@ -0,0 +1 @@
+1.0.0
@@ -1,6 +1,7 @@
 from .checkpoint import MoECheckpintIO
 from .experts import MLPExperts
-from .layers import SparseMLP
+from .layers import SparseMLP, apply_load_balance
+from .manager import MOE_MANAGER
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator
 
@@ -14,4 +15,6 @@ __all__ = [
     "UniformNoiseGenerator",
     "SparseMLP",
     "MoECheckpintIO",
+    "MOE_MANAGER",
+    "apply_load_balance",
 ]
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
             self.ep_size = 1
 
         if gated:
-            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+            self.wi_gate = nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+                )
+            )
             self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         else:
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
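
The conditional above presumably exists so the converted Mixtral weights fit: with a plain gated "silu" MLP the gate projection is no longer doubled, so wi_gate lines up with the stacked per-expert w1 tensors from the conversion hunk earlier in this commit. A toy shape check (sizes made up for illustration):

import torch

num_experts, hidden_size, intermediate_size = 2, 4, 8
for activation in ("swiglu", "silu"):
    gate_dim = intermediate_size * 2 if activation == "swiglu" else intermediate_size
    wi_gate = torch.empty(num_experts, hidden_size, gate_dim)
    print(activation, tuple(wi_gate.shape))
# swiglu -> (2, 4, 16), silu -> (2, 4, 8); the silu case matches Mixtral's per-expert w1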
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
+        router_loss: bool = True,
+        router_norm: bool = False,
         router_capacity_factor_train: float = 1.25,
         router_capacity_factor_eval: float = 2.0,
         router_min_capacity: int = 4,
|
@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
|
||||||
enable_kernel: bool = False,
|
enable_kernel: bool = False,
|
||||||
enable_comm_overlap: bool = False,
|
enable_comm_overlap: bool = False,
|
||||||
enable_hierarchical_comm: bool = False,
|
enable_hierarchical_comm: bool = False,
|
||||||
|
return_gate_logits: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.gated = mlp_gated
|
self.gated = mlp_gated
|
||||||
|
self.return_gate_logits = return_gate_logits
|
||||||
self.enable_kernel = enable_kernel
|
self.enable_kernel = enable_kernel
|
||||||
self.enable_comm_overlap = enable_comm_overlap
|
self.enable_comm_overlap = enable_comm_overlap
|
||||||
self.expert_parallel = MOE_MANAGER.get_parallel()
|
self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||||
|
self.router_loss = router_loss
|
||||||
|
self.router_norm = router_norm
|
||||||
|
|
||||||
# moe router
|
# moe router
|
||||||
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
|
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
|
||||||
|
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
         tokens = inputs.reshape(-1, self.hidden_size)
 
         # the data type of the inputs in the gating should be fp32
-        fp32_input = tokens.to(torch.float)
-        fp32_weight = self.gate_weight.to(torch.float)
-        gate_output = F.linear(fp32_input, fp32_weight)
+        gate_logits = F.linear(tokens, self.gate_weight)
+        gate_output = gate_logits.to(torch.float)
 
         # update expert load
         if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
 
         # the result from the router
         used_capacity, *route_result_list = self.router(
-            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+            inputs=gate_output,
+            use_kernel=self.enable_kernel,
+            ep_group=self.ep_group,
+            use_loss=self.router_loss,
+            use_norm=self.router_norm,
+        )
 
         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
 
         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
-            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
-                                      "Please use Experts build function.")
+            raise NotImplementedError(
+                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
+            )
 
         if self.enable_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
             ans = torch.matmul(combine_weights, expert_output)
 
         ans = ans.reshape(inputs.shape)
-        return ans
+
+        if self.return_gate_logits:
+            return ans, gate_logits
+        else:
+            return ans
 
     def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_in = expert_in.unsqueeze(0)
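
Since forward can now return either a tensor or an (output, gate_logits) pair, callers have to branch on the flag. A minimal usage sketch; sparse_mlp and hidden_states are placeholder names, not from this diff:

out = sparse_mlp(hidden_states)
if sparse_mlp.return_gate_logits:
    # mirrors MixtralSparseMoeBlock's (hidden_states, router_logits) return convention
    hidden_states, gate_logits = out
else:
    hidden_states = out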
|
@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
|
||||||
return expert_out
|
return expert_out
|
||||||
|
|
||||||
def _ep_process(
|
def _ep_process(
|
||||||
self,
|
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
|
||||||
dispatch_data: torch.Tensor,
|
|
||||||
used_capacity: torch.Tensor,
|
|
||||||
overlap: bool = False
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Expert Parallel
|
Expert Parallel
|
||||||
|
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_input = HierarchicalAllToAll.apply(
+                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_output = HierarchicalAllToAll.apply(
+                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 return expert_output
         else:
             expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
         NUM_CHUNK = 4
         NUM_STAGES = 4
 
-        assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+        assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
         chunk_size = dispatch_data.shape[1] // NUM_CHUNK
         input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
         dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
         for i in range(NUM_CHUNK + NUM_STAGES - 1):
             if expert_out is not None:
                 expert_out.handle.wait()
-                output[:, :, offset:offset + chunk_size, :] = expert_out.data
+                output[:, :, offset : offset + chunk_size, :] = expert_out.data
                 offset += chunk_size
                 expert_out = None
 
             # all2all last output
             if _expert_out is not None:
-                expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+                expert_out = Capsule(
+                    *AllToAll.apply(_expert_out.data, self.ep_group, True),
+                )
                 _expert_out = None
 
             # all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
         return output
 
     def _tp_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
         NUM_CHUNK = 4
         NUM_STAGES = 4
 
-        assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
-            "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+        assert (
+            dispatch_data.shape[0] % NUM_CHUNK == 0
+        ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
         chunk_size = dispatch_data.shape[0] // NUM_CHUNK
         chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
         output = torch.empty_like(dispatch_data)
@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """
 
-    def __init__(self,
-                 k_value: int,
-                 capacity_factor_train: float,
-                 capacity_factor_eval: float,
-                 min_capacity: int,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True,
-                 use_kernel: bool = False):
+    def __init__(
+        self,
+        k_value: int,
+        capacity_factor_train: float,
+        capacity_factor_eval: float,
+        min_capacity: int,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+        use_kernel: bool = False,
+    ):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train
@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
         if router_probs.dim() == expert_indices.dim() == 2:
             router_probs = router_probs.unsqueeze(0)
             expert_indices = expert_indices.unsqueeze(0)
-        assert router_probs.dim() == expert_indices.dim() == 3, \
-            "router_probs must be 3D tensor and expert_indices must be 4D tensor"
+        assert (
+            router_probs.dim() == expert_indices.dim() == 3
+        ), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
 
         # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
         expert_mask = F.one_hot(expert_indices, num_experts)
@@ -122,25 +125,28 @@ class Top1Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """
 
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=1,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        select_policy: str = "first",
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=1,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
             self.uniform = torch.distributions.uniform.Uniform(
-                low=torch.tensor(0.0, device=get_current_device()),
-                high=torch.tensor(1.0, device=get_current_device())
+                low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
             ).rsample
 
     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
@@ -200,7 +206,7 @@ class Top1Router(MoeRouter):
             weight = mask * probs.type_as(inputs)
             combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
             sec_mask = combine_weights.bool()
-            return used_capacity, combine_weights, sec_mask
+            return used_capacity, combine_weights, sec_mask, probs
 
 
 class Top2Router(MoeRouter):
@@ -216,20 +222,31 @@ class Top2Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation.
     """
 
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=2,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=2,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
 
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_norm: bool = False,
+        use_loss: bool = True,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -246,6 +263,10 @@ class Top2Router(MoeRouter):
 
         assert inputs.dtype == torch.float
         probs = F.softmax(inputs, dim=-1)
+        if use_norm:
+            routing_weights, _ = torch.topk(probs, 2, dim=-1)
+            probs = probs / routing_weights.sum(dim=-1, keepdim=True)
+
         num_experts = probs.size(-1)
         capacity = self.get_capacity(inputs.shape)
 
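
A hedged, self-contained illustration of the use_norm branch added above: dividing by the sum of the top-2 probabilities makes the two selected experts' weights sum to 1, matching the routing-weight normalization used by Mixtral.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = F.softmax(logits, dim=-1)
top2, _ = torch.topk(probs, 2, dim=-1)
renormed = probs / top2.sum(dim=-1, keepdim=True)
print(renormed)  # the two largest entries now sum to ~1.0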
@@ -255,21 +276,22 @@ class Top2Router(MoeRouter):
         top2_idx = torch.argmax(logits_except1, dim=-1)
         mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
 
-        cmask = (mask1 + mask2)  # loss: [s, e]
+        cmask = mask1 + mask2  # loss: [s, e]
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1
 
         # calculate loss
-        expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
-        self.set_aux_loss(probs, expert_indices, num_experts)
-        self.set_z_loss(inputs)
-        self.pop_router_loss()
+        if use_loss:
+            expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
+            self.set_aux_loss(probs, expert_indices, num_experts)
+            self.set_z_loss(inputs)
+            self.pop_router_loss()
 
         if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(cmask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
 
         rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
         rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
         rank2 += torch.sum(mask1, dim=-2, keepdim=True)
 
@@ -336,15 +358,18 @@ class TopKRouter(MoeRouter):
         oversubscribed / reach capacity.
     """
 
-    def __init__(self,
-                 num_selected_experts: int,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
-                         drop_tks)
+    def __init__(
+        self,
+        num_selected_experts: int,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
+        )
 
     def forward(
         self,
@@ -410,7 +435,7 @@ class TopKRouter(MoeRouter):
         # The combine array will be used for combining expert outputs, scaled by the
         # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
         # expert_capacity].
-        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
+        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
 
         return combine_array, dispatch_mask
 
@@ -13,7 +13,6 @@ from colossalai.utils import get_current_device
 
 
 class ForceFP32Parameter(torch.nn.Parameter):
-
     def half(self, memory_format=None):
         return self.data.clone()
 
@@ -84,6 +83,8 @@ def get_activation(act: str) -> Callable:
         return torch.nn.GELU()
     elif act == "swiglu":
         return SwiGLU
+    elif act == "silu":
+        return torch.nn.SiLU()
     else:
         raise NotImplementedError("Unsupported activation function")
 
@@ -142,7 +143,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
     epsize_param_dict = dict()
     for param in model.parameters():
        if not is_moe_tensor(param):
-            ep_size = 1    # set ep_size to 1 for dp parameters
+            ep_size = 1  # set ep_size to 1 for dp parameters
         else:
             ep_size = get_ep_size(param)
         if ep_size not in epsize_param_dict:
@@ -193,18 +194,13 @@ def create_ep_hierarchical_group(
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
         nproc_per_node = int(nproc_per_node)
     else:
-        assert dist.get_world_size() % nproc_per_node == 0, \
-            "nproc_per_node should be a divisor of world_size."
+        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node
 
     intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
-        ep_intra_ranks = [
-            i * nproc_per_node + j
-            for j in range(nproc_per_node)
-            if j in ep_group_ranks
-        ]
+        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
@@ -212,10 +208,7 @@ def create_ep_hierarchical_group(
             intra_src_rank = ep_intra_ranks[0]
 
     ep_inter_node_group = None
-    ep_inter_ranks = [
-        ep_group_ranks[0] + i * nproc_per_node
-        for i in range(num_node)
-    ]
+    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks: