From 877d94bb8cf763f469ff93b0911d9e05d596a6cf Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 18 Jul 2024 08:37:06 +0000
Subject: [PATCH] [moe] init moe plugin comm setting with sp

---
 .../plugin/moe_hybrid_parallel_plugin.py    | 163 ++++++++++--------
 colossalai/shardformer/modeling/deepseek.py |   8 +-
 colossalai/shardformer/modeling/mixtral.py  |   2 +-
 tests/test_moe/modelling/test_deepseek.py   |   2 +-
 tests/test_moe/modelling/test_mixtral.py    |   9 +-
 tests/test_moe/test_moe_checkpoint.py       |   4 +-
 .../test_model/test_shard_mixtral.py        |   8 +-
 7 files changed, 101 insertions(+), 95 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 0ad3889ae..fc3340981 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,6 +1,5 @@
 import warnings
 from collections import defaultdict
-from copy import deepcopy
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
 
@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
         if "overlap_communication" not in kwargs:
-            kwargs["overlap_communication"] = False
+            kwargs["overlap_communication"] = False  # defaults to True in the superclass
 
         super().__init__(*args, **kwargs)
 
-        self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+        self.ep_size = ep_size
+        self.moe_tp_size = moe_tp_size
+
+        self._init_moe_param_comm()
+
+        self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+            self.dp_size == 1
+            and self.pp_size == 1
+            and self.enable_sequence_parallelism
+            and self.sequence_parallelism_mode == "all_to_all"
+        )
+
         if self.use_ddp:
             warnings.warn(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
             )
             self.ddp_config["find_unused_parameters"] = True
 
-        world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
-        self.ep_size = ep_size
-        self.moe_tp_size = moe_tp_size
+            if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                raise ValueError(
+                    f"If DDP is used, dp_group and moe_dp_group are expected to be the same, since DDP can only reduce grads across a single group; found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}. You might want to set ep_size=1 or zero_stage > 0."
+                )
 
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
-            raise ValueError(
-                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
-            )
-
-        # self._init_moe_param_comm()
-
-        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
-
-        # set ep_group after super init
+        # set ep_group after super().__init__()
         # TODO do it in a better way
-        self.moe_dp_group = self.pp_group
-        self.ep_group = self.pp_group
-        self.moe_tp_group = self.pp_group
-
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
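For reference, a minimal usage sketch of the reworked constructor under all_to_all sequence parallelism. The sizes, the zero_stage value, and the 8-GPU launch are illustrative assumptions, not values taken from this patch; the one hard requirement added here is that ep_size match sp_size, and zero_stage > 0 keeps use_ddp False so the dp_group/moe_dp_group consistency check above is skipped.

# Hedged sketch: all sizes below are assumptions; the distributed environment is
# assumed to already be initialized (e.g. via colossalai.launch / launch_from_torch on 8 GPUs).
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    ep_size=2,                               # must equal sp_size when sequence_parallelism_mode="all_to_all"
    moe_tp_size=1,
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,                            # dp_size > 1 here, so ZeRO (not DDP) handles gradient reduction
    precision="fp16",
    overlap_communication=False,
)
booster = Booster(plugin=plugin)

With 8 ranks and tp = pp = 1, the data-parallel dimension ends up as 4 in this layout and moe_dp_size reuses it, following the "ep_group reuses sp_group" path of _init_moe_param_comm below.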
@@ -144,48 +141,77 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         self.force_overlap_comm = force_overlap_comm
 
     def _init_moe_param_comm(self):
-        self.moe_dp_group = None
-        self.ep_group = None
-        self.moe_tp_group = None
+        world_size = dist.get_world_size()
 
-        # create submesh for ep, moe_dp, moe_tp
-        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
-            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
-        )
+        if self.enable_sequence_parallelism:
+            # if sequence parallelism is enabled, we reuse the same group for ep and sp
+            if self.sequence_parallelism_mode == "all_to_all":
+                # ep_group reuses sp_group
+                if self.ep_size != self.sp_size:
+                    raise ValueError(
+                        f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    )
 
-        global_rank = self.pg_mesh.rank
-        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+                self.moe_dp_size = self.dp_size
+                self.moe_dp_group = self.dp_group  # NOTE: the order of these assignments matters
+                self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+                self.ep_group = self.sp_group
+                self.moe_tp_group = self.tp_group
+            else:
+                raise NotImplementedError(
+                    f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
+                )
 
-        # create groups from submesh
-        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
-            # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
-            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+        else:
+            self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
 
-            # hardcode here since we only have 3 axis
-            # moe_dp_group
-            for ep_idx in range(self.ep_size):
-                for moe_tp_idx in range(self.moe_tp_size):
-                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
-                    group = dist.new_group(moe_dp_ranks)
-                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
-                        assert self.moe_dp_group is None
-                        self.moe_dp_group = group
-            # ep_group
-            for moe_dp_idx in range(self.moe_dp_size):
-                for moe_tp_idx in range(self.moe_tp_size):
-                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
-                    group = dist.new_group(ep_ranks)
-                    if pp_rank == stage_idx and global_rank in ep_ranks:
-                        assert self.ep_group is None
-                        self.ep_group = group
-            # moe_tp_group
-            for moe_dp_idx in range(self.moe_dp_size):
+            if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+                raise ValueError(
+                    f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+                )
+
+            self.moe_dp_group = None
+            self.ep_group = None
+            self.moe_tp_group = None
+
+            # create submesh for ep, moe_dp, moe_tp
+            ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+                [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+            )
+
+            global_rank = self.pg_mesh.rank
+            pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+            # create groups from submesh
+            for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+                # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
+                submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+                # hardcode here since we only have 3 axes
+                # moe_dp_group
                 for ep_idx in range(self.ep_size):
+                    for moe_tp_idx in range(self.moe_tp_size):
+                        moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                        group = dist.new_group(moe_dp_ranks)
+                        if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                            assert self.moe_dp_group is None
+                            self.moe_dp_group = group
+                # ep_group
+                for moe_dp_idx in range(self.moe_dp_size):
+                    for moe_tp_idx in range(self.moe_tp_size):
+                        ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                        group = dist.new_group(ep_ranks)
+                        if pp_rank == stage_idx and global_rank in ep_ranks:
+                            assert self.ep_group is None
+                            self.ep_group = group
+                # moe_tp_group
+                for moe_dp_idx in range(self.moe_dp_size):
+                    for ep_idx in range(self.ep_size):
+                        moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                        group = dist.new_group(moe_tp_ranks)
+                        if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                            assert self.moe_tp_group is None
+                            self.moe_tp_group = group
 
         if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
             # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )
 
         self.logger.info(
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )
@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         param_info = get_param_info(optimizer)
 
         # TODO: Support Galore + ZeRO
-        self.zero_stage
-        deepcopy(self.zero_config)
         # Replace with distributed implementation if exists
         optimizer = cast_to_distributed(optimizer)
 
         if not isinstance(model, ModelWrapper):
-            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-                self.dp_size == 1
-                and self.pp_size == 1
-                and self.enable_sequence_parallelism
-                and self.sequence_parallelism_mode == "all_to_all"
-            )
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            else:
-                dp_group = self.dp_group
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=dp_group,
+                dp_group=self.dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=self.use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     tp_process_group=self.tp_group,
                 )
             else:
-                if not (self.dp_size > 1 or self.moe_dp_size > 1):
+                if self.dp_size <= 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."
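To make the submesh slicing in _init_moe_param_comm easier to follow, here is a small standalone sketch. It uses plain NumPy, creates no process groups, and the sizes are illustrative assumptions; it only shows which ranks of a single pipeline stage would land in each moe_dp / ep / moe_tp group.

import numpy as np

# One pipeline stage holding 8 ranks, reshaped the same way as in _init_moe_param_comm:
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp.
moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
stage_ranks = list(range(8))
submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

# moe_dp groups: fix (ep, moe_tp) and vary axis 0.
for ep_idx in range(ep_size):
    for tp_idx in range(moe_tp_size):
        print("moe_dp group:", submesh[:, ep_idx, tp_idx].tolist())

# ep groups: fix (moe_dp, moe_tp) and vary axis 1.
for dp_idx in range(moe_dp_size):
    for tp_idx in range(moe_tp_size):
        print("ep group:    ", submesh[dp_idx, :, tp_idx].tolist())

# moe_tp groups: fix (moe_dp, ep) and vary axis 2.
for dp_idx in range(moe_dp_size):
    for ep_idx in range(ep_size):
        print("moe_tp group:", submesh[dp_idx, ep_idx, :].tolist())

Each rank falls into exactly one group per axis, which is why the asserts in the loops above require the corresponding attribute to still be None before it is assigned.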
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 33fac9b93..a90cd8726 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import (
-    DPGradScalerIn,
-    DPGradScalerOut,
-    EPGradScalerIn,
-    EPGradScalerOut,
-    all_to_all_uneven,
-)
+from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 2b50f013d..f51e690d1 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         selected_experts_idx = selected_experts.argsort()
         dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
-        dist.get_rank()
+
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py
index 42daea512..74c72dd06 100644
--- a/tests/test_moe/modelling/test_deepseek.py
+++ b/tests/test_moe/modelling/test_deepseek.py
@@ -23,7 +23,7 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
     dtype = torch.float16
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index 6e6f0b2b5..fe13b5b30 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -24,11 +24,10 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
-    dtype = torch.float32
-
+    dtype, precision = torch.float16, "fp16"
     rank = torch.distributed.get_rank()
     torch.cuda.set_device(dist.get_rank())
 
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         overlap_communication=False,
         initial_scale=1,
-        precision="fp32",
+        precision=precision,
     )
     booster = Booster(plugin=plugin)
 
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
 
     dist.barrier()
 
-    saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
     check_model_equal(torch_model, saved_model)
 
     dist.barrier()
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 4bcf701de..1ab52b371 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -26,9 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if loose_close(p1, p2, p1.dtype):
-            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
-            raise AssertionError(f"Model parameter {name} is not equal")
+        loose_close(p1, p2, p1.dtype)
 
 
 def get_optimizer_snapshot(optim):
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index e873f46f7..232e16f3b 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # {
         #     "tp_size": 1,
-        #     "pp_size": 2,
+        #     "pp_size": 1,
         #     "num_microbatches": 2,
         #     "ep_size": 2,
-        #     "zero_stage": 1,
+        #     "zero_stage": 0,
         #     "overlap_communication": False,
-        #     "precision": "fp32",
+        #     "precision": "fp16",
         # },  # [dp(4)] + [moe_dp(4)]
         # {
         #     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {  # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "ep_size": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
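check_model_equal now leaves the tolerance handling entirely to loose_close. As a rough sketch of the kind of dtype-dependent comparison this relies on, the helper below is hypothetical: its name, thresholds, and structure are illustrative assumptions, not the actual implementation used by the test suite.

import torch

def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    """Assert two tensors match within a dtype-dependent tolerance (illustrative thresholds)."""
    # Half precision accumulates more numerical error than fp32, so it gets looser bounds;
    # the real test helper may use different values.
    if dtype in (torch.float16, torch.bfloat16):
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-5, 1e-8
    a = a.detach().float()
    b = b.detach().float().to(a.device)
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

# Usage mirroring the loop in check_model_equal:
# for (name, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
#     loose_close_sketch(p1, p2, p1.dtype)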