[moe] init moe plugin comm setting with sp

colossalchat
hxwang 2024-07-18 08:37:06 +00:00 committed by Hongxin Liu
parent 09d6280d3e
commit 877d94bb8c
7 changed files with 101 additions and 95 deletions

View File

@@ -1,6 +1,5 @@
 import warnings
 from collections import defaultdict
-from copy import deepcopy
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
         if "overlap_communication" not in kwargs:
-            kwargs["overlap_communication"] = False
+            kwargs["overlap_communication"] = False  # default by true in super class
         super().__init__(*args, **kwargs)
-        self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+        self.ep_size = ep_size
+        self.moe_tp_size = moe_tp_size
+        self._init_moe_param_comm()
+        self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+            self.dp_size == 1
+            and self.pp_size == 1
+            and self.enable_sequence_parallelism
+            and self.sequence_parallelism_mode == "all_to_all"
+        )
         if self.use_ddp:
             warnings.warn(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
             )
             self.ddp_config["find_unused_parameters"] = True
-        world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
-        self.ep_size = ep_size
-        self.moe_tp_size = moe_tp_size
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
-            raise ValueError(
-                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
-            )
-        # self._init_moe_param_comm()
-        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
-        # set ep_group after super init
+            if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                raise ValueError(
+                    f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
+                )
+        # set ep_group after super().__init__()
         # TODO do it in a better way
-        self.moe_dp_group = self.pp_group
-        self.ep_group = self.pp_group
-        self.moe_tp_group = self.pp_group
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
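Note on the reworked use_ddp predicate above: DDP is now chosen either for plain data parallelism without ZeRO and without pipeline stages, or for the pure all_to_all sequence-parallel case where dp_size is 1. A minimal illustration (plain Python, hypothetical sizes, not part of the commit):

    # Illustrative only: mirrors the two branches of the new use_ddp condition.
    def needs_ddp(dp_size, pp_size, zero_stage, enable_sp, sp_mode):
        plain_dp = dp_size > 1 and pp_size == 1 and zero_stage == 0
        sp_all_to_all = dp_size == 1 and pp_size == 1 and enable_sp and sp_mode == "all_to_all"
        return plain_dp or sp_all_to_all

    print(needs_ddp(4, 1, 0, False, None))          # True: classic DP, no ZeRO
    print(needs_ddp(1, 1, 0, True, "all_to_all"))   # True: all_to_all SP with dp_size=1
    print(needs_ddp(4, 1, 1, False, None))          # False: ZeRO already reduces grads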
@@ -144,48 +141,77 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.force_overlap_comm = force_overlap_comm

     def _init_moe_param_comm(self):
-        self.moe_dp_group = None
-        self.ep_group = None
-        self.moe_tp_group = None
-
-        # create submesh for ep, moe_dp, moe_tp
-        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
-            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
-        )
-
-        global_rank = self.pg_mesh.rank
-        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
-
-        # create groups from submesh
-        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
-            # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
-            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-
-            # hardcode here since we only have 3 axis
-            # moe_dp_group
-            for ep_idx in range(self.ep_size):
-                for moe_tp_idx in range(self.moe_tp_size):
-                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
-                    group = dist.new_group(moe_dp_ranks)
-                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
-                        assert self.moe_dp_group is None
-                        self.moe_dp_group = group
-            # ep_group
-            for moe_dp_idx in range(self.moe_dp_size):
-                for moe_tp_idx in range(self.moe_tp_size):
-                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
-                    group = dist.new_group(ep_ranks)
-                    if pp_rank == stage_idx and global_rank in ep_ranks:
-                        assert self.ep_group is None
-                        self.ep_group = group
-            # moe_tp_group
-            for moe_dp_idx in range(self.moe_dp_size):
-                for ep_idx in range(self.ep_size):
-                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
-                    group = dist.new_group(moe_tp_ranks)
-                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
-                        assert self.moe_tp_group is None
-                        self.moe_tp_group = group
+        world_size = dist.get_world_size()
+
+        if self.enable_sequence_parallelism:
+            # if sequence parallelism is enabled, we reuse the same group for ep and sp
+            if self.sequence_parallelism_mode == "all_to_all":
+                # when sequence parallelism is enabled, ep_group reuses sp_group
+                if self.ep_size != self.sp_size:
+                    raise ValueError(
+                        f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    )
+
+                self.moe_dp_size = self.dp_size
+                self.moe_dp_group = self.dp_group  # NOTE: sequence of value assignment matters
+                self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+                self.ep_group = self.sp_group
+                self.moe_tp_group = self.tp_group
+            else:
+                raise NotImplementedError(
+                    f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
+                )
+        else:
+            self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
+
+            if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+                raise ValueError(
+                    f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+                )
+
+            self.moe_dp_group = None
+            self.ep_group = None
+            self.moe_tp_group = None
+
+            # create submesh for ep, moe_dp, moe_tp
+            ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+                [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+            )
+
+            global_rank = self.pg_mesh.rank
+            pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+            # create groups from submesh
+            for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+                # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
+                submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+                # hardcode here since we only have 3 axis
+                # moe_dp_group
+                for ep_idx in range(self.ep_size):
+                    for moe_tp_idx in range(self.moe_tp_size):
+                        moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                        group = dist.new_group(moe_dp_ranks)
+                        if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                            assert self.moe_dp_group is None
+                            self.moe_dp_group = group
+                # ep_group
+                for moe_dp_idx in range(self.moe_dp_size):
+                    for moe_tp_idx in range(self.moe_tp_size):
+                        ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                        group = dist.new_group(ep_ranks)
+                        if pp_rank == stage_idx and global_rank in ep_ranks:
+                            assert self.ep_group is None
+                            self.ep_group = group
+                # moe_tp_group
+                for moe_dp_idx in range(self.moe_dp_size):
+                    for ep_idx in range(self.ep_size):
+                        moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                        group = dist.new_group(moe_tp_ranks)
+                        if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                            assert self.moe_tp_group is None
+                            self.moe_tp_group = group

         if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
             # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )

         self.logger.info(
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )
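To see what the submesh bookkeeping in _init_moe_param_comm produces, here is a self-contained NumPy sketch (hypothetical sizes, no process groups are actually created) of how the flat rank list of one pipeline stage is reshaped into a (moe_dp, ep, moe_tp) cube and sliced into the three kinds of groups:

    import numpy as np

    # Hypothetical sizes: 8 ranks in one pipeline stage, ep_size=2, moe_tp_size=2 -> moe_dp_size=2.
    moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
    stage_ranks = list(range(8))

    # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp (same layout as _init_moe_param_comm)
    submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

    # moe_dp groups: fix (ep, moe_tp), vary moe_dp
    moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
    # ep groups: fix (moe_dp, moe_tp), vary ep
    ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
    # moe_tp groups: fix (moe_dp, ep), vary moe_tp
    moe_tp_groups = [submesh[d, e, :].tolist() for d in range(moe_dp_size) for e in range(ep_size)]

    print(moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(ep_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Each rank falls into exactly one group per axis, which is why the asserts on self.moe_dp_group, self.ep_group, and self.moe_tp_group being None can hold.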
@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         param_info = get_param_info(optimizer)

         # TODO: Support Galore + ZeRO
-        self.zero_stage
-        deepcopy(self.zero_config)
         # Replace with distributed implementation if exists
         optimizer = cast_to_distributed(optimizer)

         if not isinstance(model, ModelWrapper):
-            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-                self.dp_size == 1
-                and self.pp_size == 1
-                and self.enable_sequence_parallelism
-                and self.sequence_parallelism_mode == "all_to_all"
-            )
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            else:
-                dp_group = self.dp_group
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=dp_group,
+                dp_group=self.dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=self.use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     tp_process_group=self.tp_group,
                 )
             else:
-                if not (self.dp_size > 1 or self.moe_dp_size > 1):
+                if self.dp_size <= 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."

View File

@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import is_flash_attn_2_available, logging

 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import (
-    DPGradScalerIn,
-    DPGradScalerOut,
-    EPGradScalerIn,
-    EPGradScalerOut,
-    all_to_all_uneven,
-)
+from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig

View File

@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         selected_experts_idx = selected_experts.argsort()
         dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
-        dist.get_rank()
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
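The surrounding dispatch logic sorts token slots by their routed expert, counts how many slots go to each expert (input_split_sizes), and asks the peers for their counts via all_to_all_single before the uneven all-to-all of hidden states. A single-process sketch of the local bookkeeping, with hypothetical sizes and no ep_group:

    # Single-process sketch of the dispatch bookkeeping above (hypothetical sizes, no process group).
    import torch

    num_experts, top_k, num_tokens, hidden = 4, 1, 6, 8
    hidden_states = torch.randn(num_tokens, hidden)
    # one routed expert id per (token, top_k) slot, as produced by the router
    selected_experts = torch.randint(0, num_experts, (num_tokens * top_k,))

    # sort token slots so that slots routed to the same expert are contiguous
    selected_experts_idx = selected_experts.argsort()
    dispatch_states = hidden_states.repeat(top_k, 1)[selected_experts_idx]

    # how many slots this rank sends to each expert; with EP these become the
    # input_split_sizes of all_to_all_uneven, while the matching output_split_sizes
    # come back from the peers via dist.all_to_all_single on the ep_group
    input_split_sizes = selected_experts.bincount(minlength=num_experts)

    print(input_split_sizes.tolist(), tuple(dispatch_states.shape))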

View File

@@ -23,7 +23,7 @@ NUM_HEADS = 4
 TOP_K = 1


-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
     dtype = torch.float16

View File

@@ -24,11 +24,10 @@ NUM_HEADS = 4
 TOP_K = 1


-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
-    dtype = torch.float32
+    dtype, precision = torch.float16, "fp16"
     rank = torch.distributed.get_rank()
     torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         overlap_communication=False,
         initial_scale=1,
-        precision="fp32",
+        precision=precision,
     )
     booster = Booster(plugin=plugin)
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()

-    saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
     check_model_equal(torch_model, saved_model)
     dist.barrier()
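The added .to(dtype) matters because the test now trains in fp16 while from_pretrained reloads the checkpoint in its stored dtype; casting before the comparison keeps both sides in the same precision. A tiny illustration with hypothetical tensors:

    import torch

    p_trained = torch.randn(4, dtype=torch.float16)
    p_reloaded = p_trained.float()  # e.g. weights may come back as fp32 from disk

    # compare in a common dtype, mirroring `.cuda().to(dtype)` in the test
    torch.testing.assert_close(p_reloaded.to(torch.float16), p_trained, rtol=1e-3, atol=1e-3)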

View File

@@ -26,9 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if loose_close(p1, p2, p1.dtype):
-            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
-            raise AssertionError(f"Model parameter {name} is not equal")
+        loose_close(p1, p2, p1.dtype)


 def get_optimizer_snapshot(optim):
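With this change, loose_close is expected to raise on mismatch by itself, so check_model_equal no longer needs its own print/raise branch. For readers without the test utilities at hand, a comparable tolerance check can be sketched with plain torch (the helper name and tolerances below are assumptions, not what loose_close actually uses):

    import torch

    # Hypothetical stand-in for loose_close: dtype-dependent tolerances are assumptions.
    def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype, name: str = "") -> None:
        rtol, atol = (1e-2, 1e-3) if dtype in (torch.float16, torch.bfloat16) else (1e-5, 1e-8)
        torch.testing.assert_close(a.to(dtype), b.to(dtype), rtol=rtol, atol=atol,
                                   msg=f"parameter {name} differs beyond tolerance")

    loose_close_sketch(torch.ones(3), torch.ones(3) + 1e-9, torch.float32, name="demo")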

View File

@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # {
         #     "tp_size": 1,
-        #     "pp_size": 2,
+        #     "pp_size": 1,
         #     "num_microbatches": 2,
         #     "ep_size": 2,
-        #     "zero_stage": 1,
+        #     "zero_stage": 0,
         #     "overlap_communication": False,
-        #     "precision": "fp32",
+        #     "precision": "fp16",
         # },  # [dp(4)] + [moe_dp(4)]
         # {
         #     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {  # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "ep_size": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",