[moe] init moe plugin comm setting with sp

colossalchat
hxwang 2024-07-18 08:37:06 +00:00 committed by Hongxin Liu
parent 09d6280d3e
commit 877d94bb8c
7 changed files with 101 additions and 95 deletions


@@ -1,6 +1,5 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
kwargs["overlap_communication"] = False
kwargs["overlap_communication"] = False # default by true in super class
super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
self._init_moe_param_comm()
self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if self.use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True
world_size = dist.get_world_size()
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
)
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
# self._init_moe_param_comm()
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
# set ep_group after super init
# set ep_group after super().__init__()
# TODO do it in a better way
self.moe_dp_group = self.pp_group
self.ep_group = self.pp_group
self.moe_tp_group = self.pp_group
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
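A minimal standalone sketch of the size bookkeeping performed above, with illustrative numbers (not the plugin's actual code): moe_dp_size is whatever remains after the other mesh axes are fixed, and the axis sizes must factor the world size exactly.

# Illustrative only: mirrors the moe_dp_size derivation and world-size check in __init__.
world_size = 16
pp_size, ep_size, moe_tp_size, sp_size = 2, 2, 1, 2

# moe_dp_size is the leftover factor once the other axes are chosen.
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size * sp_size)  # -> 2

# The product of all axes must reproduce the world size, otherwise the plugin raises ValueError.
assert pp_size * moe_dp_size * ep_size * moe_tp_size * sp_size == world_size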
@@ -144,48 +141,77 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.force_overlap_comm = force_overlap_comm
def _init_moe_param_comm(self):
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
world_size = dist.get_world_size()
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)
if self.enable_sequence_parallelism:
# if sequence parallelism is enabled, we reuse the same group for ep and sp
if self.sequence_parallelism_mode == "all_to_all":
# when sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
)
global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
self.moe_dp_size = self.dp_size
self.moe_dp_group = self.dp_group # NOTE: sequence of value assignment matters
self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.ep_group = self.sp_group
self.moe_tp_group = self.tp_group
else:
raise NotImplementedError(
f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
)
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
else:
self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
# hardcoded here since we only have 3 axes
# moe_dp_group
for ep_idx in range(self.ep_size):
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)
global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
# hardcoded here since we only have 3 axes
# moe_dp_group
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group
if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
# NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.logger.info(
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
ranks=[0],
)
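A small self-contained numpy sketch (illustrative ranks, not the plugin's code) of how each pipeline stage's ranks are reshaped into a (moe_dp, ep, moe_tp) cube and sliced to obtain the rank lists passed to dist.new_group above:

import numpy as np

moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
stage_ranks = list(range(moe_dp_size * ep_size * moe_tp_size))  # ranks of one pp stage
submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
moe_dp_ranks = submesh[:, 0, 0].tolist()  # vary moe_dp, fix ep and moe_tp  -> [0, 4]
ep_ranks = submesh[0, :, 0].tolist()      # vary ep, fix moe_dp and moe_tp  -> [0, 2]
moe_tp_ranks = submesh[0, 0, :].tolist()  # vary moe_tp, fix moe_dp and ep  -> [0, 1]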
@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
param_info = get_param_info(optimizer)
# TODO: Support Galore + ZeRO
self.zero_stage
deepcopy(self.zero_config)
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
use_ddp=self.use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
tp_process_group=self.tp_group,
)
else:
if not (self.dp_size > 1 or self.moe_dp_size > 1):
if self.dp_size <= 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."


@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig


@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
dist.get_rank()
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
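The split-size exchange above follows the usual uneven all-to-all handshake: every rank first announces how many rows it will send to each peer, and the actual token dispatch then uses those counts. A generic sketch of that handshake (standard torch.distributed, assuming one count per rank in the group; in the block above the counts come from a per-expert bincount):

import torch
import torch.distributed as dist

def exchange_split_sizes(input_split_sizes: torch.Tensor, group=None) -> torch.Tensor:
    # input_split_sizes[i]: rows this rank will send to rank i (one entry per rank, on the comm device).
    # Returns output_split_sizes[i]: rows this rank will receive from rank i.
    output_split_sizes = torch.zeros_like(input_split_sizes)
    dist.all_to_all_single(output_split_sizes, input_split_sizes, group=group)
    return output_split_sizes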


@@ -23,7 +23,7 @@ NUM_HEADS = 4
TOP_K = 1
@parameterize("config", [(1, 1, 1)])
@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float16


@@ -24,11 +24,10 @@ NUM_HEADS = 4
TOP_K = 1
@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float32
dtype, precision = torch.float16, "fp16"
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
precision="fp32",
precision=precision,
)
booster = Booster(plugin=plugin)
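For context, a rough sketch of how the test wires these arguments into the plugin and booster (assumed wiring; `model`, `optimizer` and the distributed setup are omitted, and `stage`, `ep_size`, `tp_size`, `precision` come from the parameterized config):

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    ep_size=ep_size,
    tp_size=tp_size,
    pp_size=1,
    zero_stage=stage,
    overlap_communication=False,
    initial_scale=1,
    precision=precision,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)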
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
dist.barrier()
saved_model = MixtralModel.from_pretrained(model_dir).cuda()
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(torch_model, saved_model)
dist.barrier()


@@ -26,9 +26,7 @@ top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
if loose_close(p1, p2, p1.dtype):
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")
loose_close(p1, p2, p1.dtype)
def get_optimizer_snapshot(optim):
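check_model_equal now leans on loose_close to assert closeness internally. A rough standalone equivalent of such a dtype-aware comparison (tolerances are illustrative, not the repo's exact values):

import torch

def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype, name: str = "") -> None:
    # Looser bounds for half-precision dtypes; illustrative tolerances only.
    rtol, atol = {
        torch.float16: (1e-2, 1e-3),
        torch.bfloat16: (4e-2, 4e-3),
    }.get(dtype, (1e-5, 1e-8))
    if not torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol):
        raise AssertionError(f"Parameter {name} mismatch exceeds tolerance")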


@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
# {
# "tp_size": 1,
# "pp_size": 2,
# "pp_size": 1,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "zero_stage": 0,
# "overlap_communication": False,
# "precision": "fp32",
# "precision": "fp16",
# }, # [dp(4)] + [moe_dp(4)]
# {
# "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 1,
"sp_size": 4,
"sp_size": 2,
"ep_size": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",