mirror of https://github.com/hpcaitech/ColossalAI
[moe] init moe plugin comm setting with sp
parent 09d6280d3e
commit 877d94bb8c
@@ -1,6 +1,5 @@
 import warnings
 from collections import defaultdict
-from copy import deepcopy
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple

@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
         if "overlap_communication" not in kwargs:
-            kwargs["overlap_communication"] = False
+            kwargs["overlap_communication"] = False  # default by true in super class

         super().__init__(*args, **kwargs)

-        self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+        self.ep_size = ep_size
+        self.moe_tp_size = moe_tp_size
+
+        self._init_moe_param_comm()
+
+        self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+            self.dp_size == 1
+            and self.pp_size == 1
+            and self.enable_sequence_parallelism
+            and self.sequence_parallelism_mode == "all_to_all"
+        )
+
         if self.use_ddp:
             warnings.warn(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
             )
             self.ddp_config["find_unused_parameters"] = True

-        world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
-        self.ep_size = ep_size
-        self.moe_tp_size = moe_tp_size
-
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
-            raise ValueError(
-                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
-            )
-
-        # self._init_moe_param_comm()
-
-        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
-
-        # set ep_group after super init
+            if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                raise ValueError(
+                    f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
+                )
+
+        # set ep_group after super().__init__()
         # TODO do it in a better way
-        self.moe_dp_group = self.pp_group
-        self.ep_group = self.pp_group
-        self.moe_tp_group = self.pp_group
-
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
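
The hunk above extends the `use_ddp` condition so that plain all_to_all sequence parallelism (with `dp_size == 1`) also goes through PyTorch DDP. A minimal sketch of that predicate as a standalone function, with hypothetical plain arguments standing in for the plugin attributes:

def should_use_ddp(dp_size: int, pp_size: int, zero_stage: int,
                   enable_sequence_parallelism: bool, sequence_parallelism_mode: str) -> bool:
    # Original case: real data parallelism, no pipeline, no ZeRO.
    plain_dp = dp_size > 1 and pp_size == 1 and zero_stage == 0
    # Case added by this commit: no data parallelism, but all_to_all sequence
    # parallelism is on, so a DDP-style wrapper is still requested.
    sp_only = (
        dp_size == 1
        and pp_size == 1
        and enable_sequence_parallelism
        and sequence_parallelism_mode == "all_to_all"
    )
    return plain_dp or sp_only

assert should_use_ddp(4, 1, 0, False, "") is True
assert should_use_ddp(1, 1, 0, True, "all_to_all") is True
assert should_use_ddp(1, 1, 1, False, "") is False
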
@@ -144,48 +141,77 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.force_overlap_comm = force_overlap_comm

     def _init_moe_param_comm(self):
-        self.moe_dp_group = None
-        self.ep_group = None
-        self.moe_tp_group = None
-
-        # create submesh for ep, moe_dp, moe_tp
-        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
-            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
-        )
-
-        global_rank = self.pg_mesh.rank
-        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
-
-        # create groups from submesh
-        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
-            # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
-            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-
-            # hardcode here since we only have 3 axis
-            # moe_dp_group
-            for ep_idx in range(self.ep_size):
-                for moe_tp_idx in range(self.moe_tp_size):
-                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
-                    group = dist.new_group(moe_dp_ranks)
-                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
-                        assert self.moe_dp_group is None
-                        self.moe_dp_group = group
-            # ep_group
-            for moe_dp_idx in range(self.moe_dp_size):
-                for moe_tp_idx in range(self.moe_tp_size):
-                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
-                    group = dist.new_group(ep_ranks)
-                    if pp_rank == stage_idx and global_rank in ep_ranks:
-                        assert self.ep_group is None
-                        self.ep_group = group
-            # moe_tp_group
-            for moe_dp_idx in range(self.moe_dp_size):
-                for ep_idx in range(self.ep_size):
-                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
-                    group = dist.new_group(moe_tp_ranks)
-                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
-                        assert self.moe_tp_group is None
-                        self.moe_tp_group = group
+        world_size = dist.get_world_size()
+
+        if self.enable_sequence_parallelism:
+            # if sequence parallelism is enabled, we reuse the same group for ep and sp
+            if self.sequence_parallelism_mode == "all_to_all":
+                # when sequence parallelism is enabled, ep_group reuses sp_group
+                if self.ep_size != self.sp_size:
+                    raise ValueError(
+                        f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    )
+
+                self.moe_dp_size = self.dp_size
+                self.moe_dp_group = self.dp_group  # NOTE: sequence of value assignment matters
+                self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+                self.ep_group = self.sp_group
+                self.moe_tp_group = self.tp_group
+            else:
+                raise NotImplementedError(
+                    f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
+                )
+        else:
+            self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
+
+            if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+                raise ValueError(
+                    f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+                )
+
+            self.moe_dp_group = None
+            self.ep_group = None
+            self.moe_tp_group = None
+
+            # create submesh for ep, moe_dp, moe_tp
+            ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+                [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+            )
+
+            global_rank = self.pg_mesh.rank
+            pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+            # create groups from submesh
+            for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+                # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
+                submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)

+                # hardcode here since we only have 3 axis
+                # moe_dp_group
+                for ep_idx in range(self.ep_size):
+                    for moe_tp_idx in range(self.moe_tp_size):
+                        moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                        group = dist.new_group(moe_dp_ranks)
+                        if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                            assert self.moe_dp_group is None
+                            self.moe_dp_group = group
+                # ep_group
+                for moe_dp_idx in range(self.moe_dp_size):
+                    for moe_tp_idx in range(self.moe_tp_size):
+                        ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                        group = dist.new_group(ep_ranks)
+                        if pp_rank == stage_idx and global_rank in ep_ranks:
+                            assert self.ep_group is None
+                            self.ep_group = group
+                # moe_tp_group
+                for moe_dp_idx in range(self.moe_dp_size):
+                    for ep_idx in range(self.ep_size):
+                        moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                        group = dist.new_group(moe_tp_ranks)
+                        if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                            assert self.moe_tp_group is None
+                            self.moe_tp_group = group

         if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
             # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
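
In the non-sequence-parallel branch above, each pipeline stage's ranks are reshaped into a (moe_dp, ep, moe_tp) submesh and sliced along one axis per process group. A self-contained sketch of that slicing with made-up sizes (one stage of 8 ranks, moe_dp=2, ep=2, moe_tp=2); the torch.distributed group creation is omitted:

import numpy as np

moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
stage_ranks = list(range(moe_dp_size * ep_size * moe_tp_size))  # one pp stage, ranks 0..7

# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp (same layout as the plugin)
submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

# moe_dp groups: vary axis 0, fix (ep, moe_tp)
moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
# ep groups: vary axis 1, fix (moe_dp, moe_tp)
ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
# moe_tp groups: vary axis 2, fix (moe_dp, ep)
moe_tp_groups = [submesh[d, e, :].tolist() for d in range(moe_dp_size) for e in range(ep_size)]

print("moe_dp groups:", moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print("ep groups:    ", ep_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7]]
print("moe_tp groups:", moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Every rank lands in exactly one group per axis, which is what the `assert ... is None` checks in the hunk enforce.
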
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )

         self.logger.info(
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )

@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         param_info = get_param_info(optimizer)

         # TODO: Support Galore + ZeRO
-        self.zero_stage
-        deepcopy(self.zero_config)
         # Replace with distributed implementation if exists
         optimizer = cast_to_distributed(optimizer)

         if not isinstance(model, ModelWrapper):
-            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-                self.dp_size == 1
-                and self.pp_size == 1
-                and self.enable_sequence_parallelism
-                and self.sequence_parallelism_mode == "all_to_all"
-            )
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            else:
-                dp_group = self.dp_group
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=dp_group,
+                dp_group=self.dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=self.use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     tp_process_group=self.tp_group,
                 )
             else:
-                if not (self.dp_size > 1 or self.moe_dp_size > 1):
+                if self.dp_size <= 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import is_flash_attn_2_available, logging

 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import (
-    DPGradScalerIn,
-    DPGradScalerOut,
-    EPGradScalerIn,
-    EPGradScalerOut,
-    all_to_all_uneven,
-)
+from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig
@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         selected_experts_idx = selected_experts.argsort()
         dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
-        dist.get_rank()
+
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

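
The context lines above build the expert dispatch order with argsort and the per-expert send counts with bincount before exchanging counts over ep_group. A single-process sketch of just that bookkeeping; the all_to_all exchange itself needs multiple ranks and is left out, and the toy routing, num_experts, and top_k below are made up:

import torch

num_experts, top_k = 4, 2
hidden_states = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # 3 tokens, hidden size 2

# toy router output: one expert id per row of the repeated hidden states (one entry per token copy)
selected_experts = torch.tensor([0, 2, 1, 2, 0, 3])

# sort token copies by expert id so each expert's tokens are contiguous
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(top_k, 1)[selected_experts_idx]

# how many token copies this rank would send to each expert
input_split_sizes = selected_experts.bincount(minlength=num_experts)
print(input_split_sizes.tolist())  # [2, 1, 2, 1]
print(dispatch_states.shape)       # torch.Size([6, 2])

On multiple ranks, `dist.all_to_all_single(output_split_sizes, input_split_sizes, group=ep_group)` then tells each rank how many token copies it will receive from every peer.
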
@@ -23,7 +23,7 @@ NUM_HEADS = 4
 TOP_K = 1


-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
     dtype = torch.float16
@@ -24,11 +24,10 @@ NUM_HEADS = 4
 TOP_K = 1


-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
-    dtype = torch.float32
+    dtype, precision = torch.float16, "fp16"

     rank = torch.distributed.get_rank()
     torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         overlap_communication=False,
         initial_scale=1,
-        precision="fp32",
+        precision=precision,
     )
     booster = Booster(plugin=plugin)

@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):

     dist.barrier()

-    saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
     check_model_equal(torch_model, saved_model)

     dist.barrier()
@@ -26,9 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if loose_close(p1, p2, p1.dtype):
-            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
-            raise AssertionError(f"Model parameter {name} is not equal")
+        loose_close(p1, p2, p1.dtype)


 def get_optimizer_snapshot(optim):
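
After this change, check_model_equal relies on loose_close to raise on mismatch instead of returning a flag. A hypothetical assert-style comparison with dtype-dependent tolerances, to illustrate that contract; this is not the repo's loose_close implementation, and the tolerance values are guesses:

import torch

def assert_loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype, name: str = "") -> None:
    # Tolerances guessed per dtype; fp16/bf16 need looser bounds than fp32.
    if dtype is torch.float16:
        rtol, atol = 5e-3, 5e-3
    elif dtype is torch.bfloat16:
        rtol, atol = 4e-2, 4e-2
    else:
        rtol, atol = 1e-5, 1e-8
    if not torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol):
        raise AssertionError(f"Tensor {name} not close: max diff {(a.float() - b.float()).abs().max().item()}")

# usage sketch
assert_loose_close(torch.ones(4, dtype=torch.float16), torch.ones(4, dtype=torch.float16) + 1e-4, torch.float16)
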
@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # {
         #     "tp_size": 1,
-        #     "pp_size": 2,
+        #     "pp_size": 1,
         #     "num_microbatches": 2,
         #     "ep_size": 2,
-        #     "zero_stage": 1,
+        #     "zero_stage": 0,
         #     "overlap_communication": False,
-        #     "precision": "fp32",
+        #     "precision": "fp16",
         # }, # [dp(4)] + [moe_dp(4)]
         # {
         #     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         { # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "ep_size": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",