From dec6e25e998d5511899151a0cff216b54f2dad8a Mon Sep 17 00:00:00 2001
From: botbw
Date: Mon, 8 Jul 2024 05:13:49 +0000
Subject: [PATCH] [test] pass mixtral shardformer test

---
 .../booster/plugin/hybrid_parallel_plugin.py  |  6 +++
 .../plugin/moe_hybrid_parallel_plugin.py      |  4 +-
 colossalai/shardformer/policies/mixtral.py    | 14 ++---
 .../test_model/test_shard_mixtral.py          | 54 ++++++++++++-------
 4 files changed, 51 insertions(+), 27 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 983ddfc97..ddfe0b2d9 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -38,6 +38,7 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
+from colossalai.logging import get_dist_logger
 
 from .pp_plugin_base import PipelinePluginBase
 
@@ -1016,6 +1017,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_allgather: bool = False,
     ) -> None:
         super().__init__()
+
+        self.logger = get_dist_logger(type(self).__name__)
+
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -1064,6 +1068,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 
+        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}")
+
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 02a87ff11..b2ee9f650 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -24,7 +24,6 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
-
 class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
         self,
@@ -115,6 +114,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
         self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
 
+        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}")
+
         # set ep_group after super init
         # TODO do it in a better way
         self.shard_config.ep_group = self.ep_group
@@ -168,7 +169,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 )
             else:
                 assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
-                assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                 optimizer = MoeHybridParallelZeroOptimizer(
                     optimizer,
                     model,
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 98554c906..410515362 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -20,13 +20,15 @@ class MixtralPolicy(Policy):
 
     def preprocess(self):
         if self.shard_config.enable_tensor_parallelism:
-            # Resize embedding
-            vocab_size = self.model.config.vocab_size
-            world_size = self.shard_config.tensor_parallel_size
+            raise NotImplementedError
+
+            # # Resize embedding
+            # vocab_size = self.model.config.vocab_size
+            # world_size = self.shard_config.tensor_parallel_size
 
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
+            # if vocab_size % world_size != 0:
+            #     new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            #     self.model.resize_token_embeddings(new_vocab_size)
 
         return self.model
 
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 98f7213a3..4a5f3e14d 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -37,6 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
     # unwrap model
     mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
     shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
@@ -81,15 +90,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_optimizer.step()
     sharded_optimizer.step()
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config["precision"] == "fp32":
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config["precision"] == "fp32":
@@ -121,16 +121,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "zero_stage": 0,
             "precision": "fp32",
         },  # pp + ep
-        # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "fp16"},  # full dp for moe and non-moe
-        # { # moe_dp = 2, non_moe_dp = 4
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "ep_size": 2,
-        #     "zero_stage": 1,
-        #     "precision": "fp16",
-        # },  # moe_dp = 1, non_moe_dp = 4
-        # {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp16"},
-        # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"},  # full dp for moe and non-moe
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 0,
+            "precision": "fp32",
+        },  # pp
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 4,
+            "zero_stage": 0,
+            "precision": "fp32",
+        },  # pp + ep
+        {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "bf16"},  # full dp for moe and non-moe
+        {  # moe_dp = 2, non_moe_dp = 4
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 2,
+            "zero_stage": 1,
"precision": "fp32", + }, # moe_dp = 1, non_moe_dp = 4 + {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp32"}, # full dp for non-moe and full ep for moe + {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe ], ) def run_mixtral_test(test_config):