From b5bfeb2efd8ef213a03deaa57a175886fbe3e112 Mon Sep 17 00:00:00 2001 From: botbw Date: Mon, 8 Jul 2024 09:59:46 +0000 Subject: [PATCH] [moe] implement transit between non moe tp and ep --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- .../plugin/moe_hybrid_parallel_plugin.py | 16 +-- colossalai/moe/_operation.py | 103 ++++++++++++++ colossalai/shardformer/modeling/mixtral.py | 32 +++-- colossalai/shardformer/policies/mixtral.py | 129 ++++++++---------- colossalai/shardformer/shard/shard_config.py | 2 + .../test_model/test_shard_mixtral.py | 49 +++++-- 7 files changed, 233 insertions(+), 100 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ddfe0b2d9..cad9ca95c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1068,7 +1068,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) - self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}") + self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]) self.stage_manager = None self.schedule = None diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b2ee9f650..0b0d50e28 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -30,8 +30,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): optimizer: Optimizer, model: Module, use_pipeline: bool, - dp_process_group: ProcessGroup, # the dp pg for comm - moe_dp_group: ProcessGroup, # the moe dp pg for gomm + dp_process_group: ProcessGroup, # dp pg for comm + moe_dp_group: ProcessGroup, # moe dp pg for comm param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, @@ -44,7 +44,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): verbose: bool = False, reduce_bucket_size: int = 1024 * 1024, # communication communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, + overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload forced_dtype: Optional[torch.dtype] = None, @@ -88,7 +88,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): TODO: add docstring """ - def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None: + def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 @@ -98,14 +98,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) self.ddp_config["find_unused_parameters"] = True - if ep_tp_size != 1: + if moe_tp_size != 1: raise NotImplementedError world_size = dist.get_world_size() - self.moe_dp_size = world_size // (ep_size * ep_tp_size) + self.moe_dp_size = world_size // (ep_size * moe_tp_size) self.ep_size = ep_size - self.moe_tp_size = ep_tp_size + self.moe_tp_size = moe_tp_size self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size) self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2 @@ -114,7 +114,7 @@ class 
MoeHybridParallelPlugin(HybridParallelPlugin): self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis) self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis) - self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}") + self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0]) # set ep_group after super init # TODO do it in a better way diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 3df349182..cad9573fb 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -397,3 +397,106 @@ def all_to_all_uneven( inputs.requires_grad ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) + + +# =========================================================== +# This code section was modified from +# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# The file has been adapted from the following Megatron-LM file: +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py +# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796 +# We retain the following copyright from the original files: + +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def _gather_tokens(input_, dim: int, tp_group: ProcessGroup): + """Gather tensors and concatenate them along a dimension""" + + input_ = input_.contiguous() + # Size and dimension. + rank = tp_group.rank() + + tensor_list = [torch.empty_like(input_) for _ in range(tp_group.size())] + tensor_list[rank] = input_ + dist.all_gather(tensor_list, input_, group=tp_group) + + # Note: torch.cat already creates a contiguous tensor. 
+    output = torch.cat(tensor_list, dim=dim).contiguous()
+
+    return output
+
+
+def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
+    """Divide a tensor among the tensor parallel ranks"""
+
+    total_chunks = tp_group.size()
+    this_chunk = tp_group.rank()
+    assert input_.shape[dim] % total_chunks == 0, \
+        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
+    chunk_size = input_.shape[dim] // total_chunks
+
+    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
+
+
+class _GatherTokens(torch.autograd.Function):
+    """All gather tokens among the tensor parallel ranks"""
+
+    @staticmethod
+    def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
+        ctx.dim = dim
+        ctx.tp_group = tp_group
+        return _gather_tokens(input_, dim, tp_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _drop_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
+
+
+class _DropTokens(torch.autograd.Function):
+    """Divide tokens equally among the tensor parallel ranks"""
+
+    @staticmethod
+    def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
+        ctx.dim = dim
+        ctx.tp_group = tp_group
+        return _drop_tokens(input_, dim, tp_group)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+        return _gather_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
+
+
+def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
+    if tp_group.size() == 1:
+        # no tensor parallelism for non-experts
+        return input_
+    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _GatherTokens.apply(input_, dim, tp_group)
+
+
+def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
+    if tp_group.size() == 1:
+        # no tensor parallelism for non-experts
+        return input_
+    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _DropTokens.apply(input_, dim, tp_group)
+
+# ===========================================================
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 334bd13fc..5d2dc1dc3 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -14,21 +14,21 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 
 
 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config, ep_group):
+    def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None):
         super().__init__(config)
-        self.setup_ep(ep_group)
+        self.setup_process_groups(ep_group, tp_group, moe_tp_group)
 
-    def setup_ep(self, ep_group: ProcessGroup):
-        ep_group = ep_group
-        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
-        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+    def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None):
+        # setup ep group
+        self.ep_size = dist.get_world_size(ep_group)
+        self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
 
         if self.num_experts % self.ep_size != 0:
@@ -42,13 +42,19 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         for p in self.experts.parameters():
             p.ep_group = ep_group
 
+        # setup global tp group
+        self.tp_group = tp_group
+
+        # setup moe tp group
+        self.moe_tp_group = moe_tp_group
+
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
+        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None, *args, **kwargs
     ) -> "EPMixtralSparseMoeBlock":
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_ep(ep_group)
+        module.setup_process_groups(ep_group, tp_group, moe_tp_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -72,6 +78,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+
+        if self.tp_group is not None and self.tp_group.size() > 1:
+            dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
         output_states = MoeInGradScaler.apply(output_states, self.ep_size)
@@ -94,6 +104,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         output_states = torch.cat(output_states_list)
         output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+
+        if self.tp_group is not None and self.tp_group.size() > 1:
+            
dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) recover_experts_idx[selected_experts_idx] = torch.arange( selected_experts_idx.size(0), device=selected_experts_idx.device diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 410515362..14d57c79d 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -8,6 +8,7 @@ from torch.nn import Module from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -20,15 +21,15 @@ class MixtralPolicy(Policy): def preprocess(self): if self.shard_config.enable_tensor_parallelism: - raise NotImplementedError - - # # Resize embedding - # vocab_size = self.model.config.vocab_size - # world_size = self.shard_config.tensor_parallel_size + # non-moe params tensor parallelism + + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size - # if vocab_size % world_size != 0: - # new_vocab_size = vocab_size + world_size - vocab_size % world_size - # self.model.resize_token_embeddings(new_vocab_size) + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model @@ -42,74 +43,62 @@ class MixtralPolicy(Policy): ) if self.shard_config.enable_tensor_parallelism: - raise NotImplementedError - # assert ( - # self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 - # ), f"The number of attention heads must be divisible by tensor parallel size." - # assert ( - # self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 - # ), f"The number of key_value heads must be divisible by tensor parallel size." 
- # decoder_attribute_replacement = { - # "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - # "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - # "self_attn.num_key_value_heads": self.model.config.num_key_value_heads - # // self.shard_config.tensor_parallel_size, - # } - - # policy[MixtralDecoderLayer] = ModulePolicyDescription( - # attribute_replacement=decoder_attribute_replacement, - # sub_module_replacement=[ - # SubModuleReplacementDescription( - # suffix="self_attn.q_proj", - # target_module=Linear1D_Col, - # kwargs={ - # 'process_group': self.shard_config.tensor_parallel_process_group, - # } - # ), - # SubModuleReplacementDescription( - # suffix="self_attn.k_proj", - # target_module=Linear1D_Col, - # kwargs={ - # 'process_group': self.shard_config.tensor_parallel_process_group, - # } - # ), - # SubModuleReplacementDescription( - # suffix="self_attn.v_proj", - # target_module=Linear1D_Col, - # kwargs={ - # 'process_group': self.shard_config.tensor_parallel_process_group, - # } - # ), - # SubModuleReplacementDescription( - # suffix="self_attn.o_proj", - # target_module=Linear1D_Row, - # kwargs={ - # 'process_group': self.shard_config.tensor_parallel_process_group, - # } - # ), - # # SubModuleReplacementDescription( - # # suffix="mlp.gate_proj", - # # target_module=Linear1D_Col, - # # ), - # # SubModuleReplacementDescription( - # # suffix="mlp.up_proj", - # # target_module=Linear1D_Col, - # # ), - # # SubModuleReplacementDescription( - # # suffix="mlp.down_proj", - # # target_module=Linear1D_Row, - # # ), - # ], - # ) - - if getattr(self.shard_config, "ep_group", None) is None: + # tensor parallelism for non-moe params + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." 
+ decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + + policy[MixtralDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + # SubModuleReplacementDescription( # TODO: enable moe tp parallel + # suffix="mlp.gate_proj", + # target_module=Linear1D_Col, + # ), + # SubModuleReplacementDescription( + # suffix="mlp.up_proj", + # target_module=Linear1D_Col, + # ), + # SubModuleReplacementDescription( + # suffix="mlp.down_proj", + # target_module=Linear1D_Row, + # ), + ], + ) + + if self.shard_config.ep_group: # expert parallel self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( suffix="block_sparse_moe", target_module=EPMixtralSparseMoeBlock, - kwargs={"ep_group": self.shard_config.ep_group}, + kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group}, ) ], policy=policy, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b64300366..d1aebd5b2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -47,6 +47,8 @@ class ShardConfig: gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) ep_group: Optional[ProcessGroup] = None + moe_tp_group: Optional[ProcessGroup] = None + # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 4a5f3e14d..70b576908 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -113,40 +113,65 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 1, + "ep_size": 1, + "zero_stage": 2, + "precision": "fp32", + }, # [dp(2) + tp(2)] + [moe_dp(4)] + { + "tp_size": 2, + "pp_size": 1, + "ep_size": 2, + "zero_stage": 2, + "precision": "fp32", + }, # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)] { "tp_size": 1, "pp_size": 2, "num_microbatches": 2, "ep_size": 1, - "zero_stage": 0, + "zero_stage": 2, "precision": "fp32", - }, # pp + ep + }, # [dp(2) + pp(2)] + [moe_dp(4)] { "tp_size": 1, "pp_size": 2, "num_microbatches": 2, "ep_size": 1, - "zero_stage": 0, + "zero_stage": 2, "precision": "fp32", - }, # pp + ep + }, # [dp(2) + pp(2)] + [moe_dp(4)] { "tp_size": 1, "pp_size": 2, "num_microbatches": 2, "ep_size": 4, - "zero_stage": 0, + "zero_stage": 2, "precision": "fp32", - }, # pp + ep - {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "bf16"}, # full dp for moe and non-moe 
-        { # moe_dp = 2, non_moe_dp = 4
+        }, # [dp(2) + pp(2)] + [ep(4)]
+        {
             "tp_size": 1,
             "pp_size": 1,
             "ep_size": 2,
-            "zero_stage": 1,
+            "zero_stage": 2,
             "precision": "fp32",
-        }, # moe_dp = 1, non_moe_dp = 4
-        {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp32"}, # full dp for non-moe and full ep for moe
-        {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
+        }, # [dp(4)] + [ep(2) + moe_dp(2)]
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 4,
+            "zero_stage": 2,
+            "precision": "fp32"
+        }, # full dp for non-moe and full ep for moe
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 1,
+            "zero_stage": 2,
+            "precision": "fp32"
+        }, # full dp for moe and non-moe
     ],
 )
 def run_mixtral_test(test_config):