diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 03b7bebb1..4b1bd0f47 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1122,6 +1122,10 @@ class HybridParallelPlugin(PipelinePluginBase): else: self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + self.logger.info( + f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}", + ranks=[0, 1, 2, 3, 4, 5, 6, 7], + ) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, sequence_parallel_process_group=self.sp_group, diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 31b346b10..32673169a 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -147,9 +147,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): world_size = dist.get_world_size() if self.enable_sequence_parallelism: - # if sequence parallelism is enabled, we reuse the same group for ep and sp if self.sequence_parallelism_mode == "all_to_all": - # when sequence parallelism is enabled, ep_group reuses sp_group + # if sequence parallelism is enabled, ep_group reuses sp_group if self.ep_size != self.sp_size: raise ValueError( f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled" @@ -157,8 +156,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): # since we are reusing sp_group, moe_dp_group will be derived as dp_group self.moe_dp_size = self.dp_size - self.moe_dp_group = self.dp_group # NOTE: sequence of value assignment matters - self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.moe_dp_group = self.dp_group + self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) self.ep_group = self.sp_group self.moe_tp_group = self.tp_group else: @@ -177,6 +176,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.moe_dp_group = None self.ep_group = None self.moe_tp_group = None + self.dp_sp_group = self.dp_group # create submesh for ep, moe_dp, moe_tp ranks_by_pp_stage = self.pg_mesh.get_group_along_axis( @@ -225,8 +225,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) self.logger.info( - f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n" - f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}", + f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n" + f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0], ) @@ -254,7 +254,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): module=model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_group, + dp_group=self.dp_sp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=self.use_ddp, 
@@ -302,7 +302,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): use_pipeline=self.enable_pipeline_parallelism, force_overlap_comm=self.force_overlap_comm, param_info=param_info, - dp_process_group=self.dp_group, + dp_process_group=self.dp_sp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, moe_dp_group=self.moe_dp_group, diff --git a/colossalai/legacy/moe/layer/experts.py b/colossalai/legacy/moe/layer/experts.py index c16fc77bb..8088cf44e 100644 --- a/colossalai/legacy/moe/layer/experts.py +++ b/colossalai/legacy/moe/layer/experts.py @@ -7,7 +7,7 @@ import torch.nn as nn from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.legacy.moe.manager import MOE_MANAGER from colossalai.legacy.moe.utils import get_activation -from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut +from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size diff --git a/colossalai/legacy/moe/layer/layers.py b/colossalai/legacy/moe/layer/layers.py index 8681b5972..e43966f68 100644 --- a/colossalai/legacy/moe/layer/layers.py +++ b/colossalai/legacy/moe/layer/layers.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from colossalai.legacy.moe.load_balance import LoadBalancer from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator -from colossalai.moe.operators import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter +from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.shardformer.layer.moe import MLPExperts from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size diff --git a/colossalai/legacy/moe/layer/routers.py b/colossalai/legacy/moe/layer/routers.py index c16fc77bb..8088cf44e 100644 --- a/colossalai/legacy/moe/layer/routers.py +++ b/colossalai/legacy/moe/layer/routers.py @@ -7,7 +7,7 @@ import torch.nn as nn from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.legacy.moe.manager import MOE_MANAGER from colossalai.legacy.moe.utils import get_activation -from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut +from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size diff --git a/colossalai/moe/operators.py b/colossalai/moe/_operation.py similarity index 100% rename from colossalai/moe/operators.py rename to colossalai/moe/_operation.py diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index a90cd8726..33fac9b93 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -10,7 +10,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import is_flash_attn_2_available, logging from colossalai.lazy import LazyInitContext -from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven +from colossalai.moe._operation import ( + DPGradScalerIn, + DPGradScalerOut, + EPGradScalerIn, + EPGradScalerOut, + all_to_all_uneven, +) from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer.linear 
import Linear1D_Col, Linear1D_Row from colossalai.shardformer.shard import ShardConfig diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index f51e690d1..90616351a 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -23,7 +23,13 @@ from transformers.models.mixtral.modeling_mixtral import ( from transformers.utils import is_flash_attn_2_available, logging from colossalai.lazy import LazyInitContext -from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven +from colossalai.moe._operation import ( + DPGradScalerIn, + DPGradScalerOut, + EPGradScalerIn, + EPGradScalerOut, + all_to_all_uneven, +) from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( all_to_all_comm, @@ -245,6 +251,7 @@ class MixtralPipelineForwards: if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: + print("input_ids", input_ids.shape) batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape @@ -372,16 +379,29 @@ class MixtralPipelineForwards: if output_router_logits and past_router_logits is not None: all_router_logits = past_router_logits + all_router_logits if stage_manager.is_last_stage(): - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - # always return dict for imediate stage - return { - "hidden_states": hidden_states, - "past_router_logits": all_router_logits, - } + else: + if output_router_logits: + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + else: + return { + "hidden_states": hidden_states, + } @staticmethod def mixtral_for_causal_lm_forward( diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 5a67d653d..04d1dcd41 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -34,7 +34,10 @@ class DeepseekPolicy(Policy): policy = {} if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False + if self.pipeline_stage_manager is not None: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.") raise NotImplementedError( "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag." 
) @@ -136,6 +139,10 @@ class DeepseekPolicy(Policy): """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: + if self.shard_config.enable_sequence_parallelism: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.") stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "DeepseekModel": module = self.model diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 8fed5ee5c..4de982f44 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -62,6 +62,10 @@ class MixtralPolicy(Policy): attribute_replacement=decoder_attribute_replacement, ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + if self.pipeline_stage_manager is not None: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.") self.append_or_create_method_replacement( description={ "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), @@ -69,19 +73,18 @@ class MixtralPolicy(Policy): policy=policy, target_key=attn_cls, ) - if self.pipeline_stage_manager is None: - self.append_or_create_method_replacement( - description={ - "forward": get_mixtral_flash_attention_model_forward( - self.shard_config, - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, - ), - }, - policy=policy, - target_key=MixtralModel, - ) + self.append_or_create_method_replacement( + description={ + "forward": get_mixtral_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=MixtralModel, + ) embedding_cls = None if self.shard_config.enable_tensor_parallelism: @@ -202,6 +205,10 @@ class MixtralPolicy(Policy): """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: + if self.shard_config.enable_sequence_parallelism: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.") stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "MixtralModel": module = self.model diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 78c34046a..19d20de2b 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -100,7 +100,7 @@ class BucketStore(BaseStore): return self._grad_in_bucket - def get_flatten_grad(self, dtype=None) -> Tensor: + def get_flatten_grad(self) -> Tensor: """Return the flattened gradients slices in the bucket, the data organization of the flattened tensor: [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] 
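The get_flatten_grad() signature change works because the flattened tensor can simply inherit its dtype from the gradients already stored in the bucket, which is presumably why the unused dtype argument was dropped. A minimal sketch of that flattening layout (an illustration only, not the ColossalAI BucketStore implementation; flatten_grads is a hypothetical helper):

import torch

def flatten_grads(grads_per_rank):
    # grads_per_rank: one list of gradient tensors per rank.
    # Resulting layout matches the docstring above:
    # [grad0_rank0, grad1_rank0, ..., grad0_rank1, grad1_rank1, ...]
    flat = [g.reshape(-1) for rank_grads in grads_per_rank for g in rank_grads]
    return torch.cat(flat)  # dtype follows the input gradients, no extra argument needed

if __name__ == "__main__":
    grads = [[torch.ones(2, 2), torch.zeros(3)], [torch.full((2,), 2.0)]]
    print(flatten_grads(grads).shape)  # torch.Size([9])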
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8bad6ebec..d7041e682 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -303,7 +303,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for bucket_store in self.pg_to_bucket_store.values(): bucket_store.build_grad_in_bucket() - flat_grads = bucket_store.get_flatten_grad(self._dtype) + flat_grads = bucket_store.get_flatten_grad() flat_grads /= bucket_store.world_size # ready to add other tensors to bucket diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py deleted file mode 100644 index 74c72dd06..000000000 --- a/tests/test_moe/modelling/test_deepseek.py +++ /dev/null @@ -1,133 +0,0 @@ -import os -import shutil -from copy import deepcopy -from typing import Tuple - -import pytest -import torch -import torch.distributed as dist -from transformers import AutoConfig, AutoModel - -import colossalai -from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import loose_close -from tests.test_moe.test_moe_checkpoint import check_model_equal - -NUM_BATCH = 4 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 -HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 - - -@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)]) -def run_zero_with_original_model(config: Tuple[int, ...]): - stage, ep_size, tp_size = config - dtype = torch.float16 - - rank = torch.distributed.get_rank() - torch.cuda.set_device(dist.get_rank()) - - plugin = MoeHybridParallelPlugin( - pp_size=1, - tp_size=tp_size, - moe_tp_size=tp_size, - ep_size=ep_size, - zero_stage=stage, - overlap_communication=False, - initial_scale=1, - precision="fp32", - ) - booster = Booster(plugin=plugin) - - seed_all(10086) - - config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True) - config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS - config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2 - config.num_hidden_layers = 2 - config.num_attention_heads = NUM_HEADS - config.num_key_value_heads = NUM_HEADS - config.n_routed_experts = NUM_EXPERTS - config.num_experts_per_tok = TOP_K - torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype) - - torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) - - zero_model = deepcopy(torch_model).to(dtype) - zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) - - zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) - - # create different input - seed_all(1453 + rank) - - torch_model.train() - zero_model.train() - for _ in range(2): - input_data = torch.rand( - NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True - ).cuda() - dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input - - zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() - zero_optimizer.backward(zero_output) - zero_optimizer.step() - zero_optimizer.zero_grad() - dist.all_reduce(zero_output) - - all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())] - dist.all_gather(all_inputs, input_data) - - torch_output_sum = 0 - for input_data_ in 
all_inputs: - torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() - torch_output.backward() - torch_output_sum += torch_output.detach() - # avg dp grads - for p in torch_model.parameters(): - if p.grad is not None: - p.grad /= dist.get_world_size() - torch_optimizer.step() - torch_optimizer.zero_grad() - - loose_close(zero_output, torch_output_sum, dtype=dtype) - - # use checkpoint to load sharded zero model - model_dir = "./test_deepseek" - if dist.get_rank() == 0: - os.makedirs(model_dir, exist_ok=True) - - dist.barrier() - - booster.save_model(zero_model, model_dir, shard=True) - - dist.barrier() - - saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda() - check_model_equal(torch_model, saved_model) - - dist.barrier() - if dist.get_rank() == 0: - shutil.rmtree(model_dir) - - print(f"{dist.get_rank()} test passed") - - -def run_dist(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_zero_with_original_model() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_mistral(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_mistral(world_size=4) diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py deleted file mode 100644 index 69d9fa5d4..000000000 --- a/tests/test_moe/modelling/test_mixtral.py +++ /dev/null @@ -1,143 +0,0 @@ -import os -import shutil -from copy import deepcopy -from typing import Tuple - -import pytest -import torch -import torch.distributed as dist -from transformers.models.mixtral.configuration_mixtral import MixtralConfig -from transformers.models.mixtral.modeling_mixtral import MixtralModel - -import colossalai -from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import loose_close -from tests.test_moe.test_moe_checkpoint import check_model_equal - -NUM_BATCH = 4 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4 -HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 - - -@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)]) -def run_zero_with_original_model(config: Tuple[int, ...]): - ep_size, stage, dp_size, pp_size, tp_size, sp_size = config - print(config) - rank = torch.distributed.get_rank() - dtype, precision = torch.float16, "fp16" - torch.cuda.set_device(dist.get_rank()) - - plugin = MoeHybridParallelPlugin( - pp_size=pp_size, - num_microbatches=pp_size, - tp_size=tp_size, - sp_size=sp_size, - ep_size=ep_size, - moe_tp_size=tp_size, - zero_stage=stage, - enable_sequence_parallelism=sp_size > 1, - sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, - overlap_communication=False, - initial_scale=1, - precision=precision, - find_unused_parameters=True, - ) - booster = Booster(plugin=plugin) - - seed_all(10086) - - config = MixtralConfig( - hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, - intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, - num_hidden_layers=2, - num_attention_heads=NUM_HEADS, - num_key_value_heads=NUM_HEADS, - num_local_experts=NUM_EXPERTS, - num_experts_per_tok=TOP_K, - attn_implementation="flash_attention_2", - ) - - torch_model = MixtralModel(config).to(dtype).cuda() - torch_optimizer 
= torch.optim.SGD(torch_model.parameters(), lr=1) - - zero_model = deepcopy(torch_model).to(dtype) - zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) - - zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) - - # create different input - seed_all(1453 + rank) - - torch_model.train() - zero_model.train() - for _ in range(2): - input_data = torch.rand( - NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True - ).cuda() - - dist.all_reduce(input_data, group=plugin.tp_group) # tp group requires duplicate input - dist.all_reduce(input_data, group=plugin.sp_group) # sp group requires duplicate input - - zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() - zero_optimizer.backward(zero_output) - zero_optimizer.step() - zero_optimizer.zero_grad() - dist.all_reduce(zero_output) - - all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())] - dist.all_gather(all_inputs, input_data) - - torch_output_sum = 0 - for input_data_ in all_inputs: - torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() - torch_output.backward() - torch_output_sum += torch_output.detach() - # avg dp grads - for p in torch_model.parameters(): - if p.grad is not None: - p.grad /= dist.get_world_size() - torch_optimizer.step() - torch_optimizer.zero_grad() - - loose_close(zero_output, torch_output_sum, dtype=dtype) - - # use checkpoint to load sharded zero model - model_dir = "./test_mixtral" - if dist.get_rank() == 0: - os.makedirs(model_dir, exist_ok=True) - - dist.barrier() - - booster.save_model(zero_model, model_dir, shard=True) - - dist.barrier() - - saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype) - check_model_equal(torch_model, saved_model) - - dist.barrier() - if dist.get_rank() == 0: - shutil.rmtree(model_dir) - - print(f"{dist.get_rank()} test passed") - - -def run_dist(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_zero_with_original_model() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [8]) -@rerun_if_address_is_in_use() -def test_mistral(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_mistral(world_size=8) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index f2c6d206f..c81023988 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -4,7 +4,7 @@ import pytest import torch from colossalai.accelerator import get_accelerator -from colossalai.moe.operators import MoeCombine, MoeDispatch, moe_cumsum +from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum NUM_EXPERTS = 4 BATCH_SIZE = 4 diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py new file mode 100644 index 000000000..96edfb487 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_deepseek.py @@ -0,0 +1,186 @@ +import os +import shutil +from copy import deepcopy +from typing import Tuple + +import pytest +import torch +import torch.distributed +import torch.distributed as dist +from transformers import AutoConfig, AutoModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random 
import seed_all +from tests.test_moe.moe_utils import loose_close +from tests.test_moe.test_moe_checkpoint import check_model_equal + +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 + + +# TODO only need to keep one or two cases +@parameterize( + "config", + [ + (2, 1, 1, 4, 1), + # (2, 1, 2, 1, 1), # TODO debug deepseek pp + # (2, 1, 2, 2, 1), # TODO debug deepseek pp + (2, 1, 1, 2, 1), + # (2, 1, 1, 1, 2), # TODO support deepseek sp + # (2, 1, 4, 1, 1), # TODO debug deepseek pp + (4, 1, 1, 1, 1), + (4, 1, 1, 2, 1), + # (4, 1, 2, 1, 1), # TODO debug deepseek pp + ], +) +def run_zero_with_original_model(config: Tuple[int, ...]): + ep_size, stage, pp_size, tp_size, sp_size = config + world_size = dist.get_world_size() + rank = dist.get_rank() + dtype, precision = torch.float16, "fp16" + torch.cuda.set_device(dist.get_rank()) + + print(config) + plugin = MoeHybridParallelPlugin( + pp_size=pp_size, + num_microbatches=pp_size, + tp_size=tp_size, + sp_size=sp_size, + ep_size=ep_size, + moe_tp_size=tp_size, + zero_stage=stage, + enable_sequence_parallelism=sp_size > 1, + sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, + overlap_communication=False, + initial_scale=1, + precision=precision, + find_unused_parameters=True, + ) + dp_size = plugin.dp_size + + booster = Booster(plugin=plugin) + + # init model with the same seed + seed_all(10086) + + assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" + config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True) + config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS + config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2 + config.num_hidden_layers = 2 + config.num_attention_heads = NUM_HEADS + config.num_key_value_heads = NUM_HEADS + config.n_routed_experts = NUM_EXPERTS + config.num_experts_per_tok = TOP_K + + torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype) + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + + parallel_model = deepcopy(torch_model) + parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1) + parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer) + + # create different input along dp axis + seed_all(1453 + rank) + + torch_model.train() + parallel_model.train() + for _ in range(2): + # gen random input + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check + + dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input + dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input + + # run the model with hybrid parallel + if booster.plugin.stage_manager is not None: + # for test with pp + data_iter = iter([{"inputs_embeds": input_embeddings}]) + sharded_output = booster.execute_pipeline( + data_iter, + parallel_model, + lambda x, y: x[0].mean(), + parallel_optimizer, + return_loss=True, + return_outputs=True, + ) + if booster.plugin.stage_manager.is_last_stage(): + parallel_output = sharded_output["loss"] + else: + parallel_output = torch.tensor(12345.0, device="cuda") + + # broadcast along pp axis + dist.broadcast( + parallel_output, 
src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group + ) + else: + # for test without pp + parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean() + parallel_optimizer.backward(parallel_output) + parallel_optimizer.step() + parallel_optimizer.zero_grad() + dist.all_reduce(parallel_output, group=plugin.dp_group) + + # =================================================================================== + # run normal model with all dp(different) inputs + all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)] + dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + torch_output_sum = 0 + for input_data_ in all_inputs: + torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() + torch_output.backward() + torch_output_sum += torch_output.detach() + # avg dp grads follows zero optimizer + for p in torch_model.parameters(): + if p.grad is not None: + p.grad /= dp_size + torch_optimizer.step() + torch_optimizer.zero_grad() + + loose_close(parallel_output, torch_output_sum, dtype=dtype) + + # use checkpoint to load sharded zero model + model_dir = "./test_mixtral" + if rank == world_size - 1: + os.makedirs(model_dir, exist_ok=True) + + dist.barrier() + booster.save_model(parallel_model, model_dir, shard=True) + dist.barrier() + + saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda() + check_model_equal(torch_model, saved_model) + dist.barrier() + + if rank == world_size - 1: + shutil.rmtree(model_dir) + + print(f"rank {dist.get_rank()} test passed") + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_zero_with_original_model() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [8]) +@rerun_if_address_is_in_use() +def test_mistral(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mistral(world_size=8) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 4e9c594d2..e0ef3bfaf 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -1,229 +1,188 @@ -# modified from test_shard_mistral.py import os +import shutil +from copy import deepcopy +from typing import Tuple import pytest import torch +import torch.distributed import torch.distributed as dist -from torch.testing import assert_close +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai +from colossalai.booster.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor.api import clear_layout_converter -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import ( - build_model_from_hybrid_plugin, - check_all_grad_tensors, - check_loss, - check_output_hidden_state, - check_weight, - get_grad_tensors_for_check, - run_forward_backward_with_hybrid_plugin, - unwrap_model, -) +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn 
+from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import loose_close +from tests.test_moe.test_moe_checkpoint import check_model_equal -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" - - -def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): - # TODO: SGD failed for full dp - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD - ) - - org_model = org_model.to(torch.float16) - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) - stage_manager = booster.plugin.stage_manager - tp_group = booster.plugin.tp_group - - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol) - - # unwrap model - mixtral_model = unwrap_model(org_model, "MixtralModel", "model") - shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model") - - row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] - col_layer_for_check = ["layers[0].self_attn.o_proj"] - - # Check the grad when using ZeRO-1 and ZeRO-2 - if ( - # booster.plugin.zero_stage in [1, 2] - booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" - ): - rank = dist.get_rank() - name_to_p = {n: p for n, p in mixtral_model.named_parameters()} - for n, p in shard_mixtral_model.named_parameters(): - zero_grad = sharded_optimizer.get_param_grad(p) - if name_to_p[n].grad is None: - name_to_p[n].grad = torch.zeros_like(name_to_p[n].data) - continue - assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False) - - # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
- grads_to_check = {} - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config["precision"] == "fp32": - atol, rtol = 5e-5, 1e-4 - else: - atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check( - mixtral_model, - shard_mixtral_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False, - ) - col_layer_grads = get_grad_tensors_for_check( - mixtral_model, - shard_mixtral_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) - grads_to_check.update(col_layer_grads) - grads_to_check.update(row_layer_grads) - - # check grads - check_all_grad_tensors(grads_to_check) - - for n, p in shard_mixtral_model.named_parameters(): - assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False) - - # optimizer executes step - org_optimizer.step() - sharded_optimizer.step() - - for n, p in shard_mixtral_model.named_parameters(): - assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False) - - # check weights - if stage_manager is None or stage_manager.is_first_stage(): - if test_config["precision"] == "fp32": - atol, rtol = 2e-4, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - try: - check_weight( - mixtral_model, - shard_mixtral_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) - except Exception as e: - rank = dist.get_rank() - print(f"{rank=}, Failed config: {test_config}") - raise e - - torch.cuda.empty_cache() +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# TODO only need to keep one or two cases @parameterize( - "test_config", + "config", [ - # { - # "tp_size": 1, - # "pp_size": 1, - # "num_microbatches": 2, - # "ep_size": 2, - # "zero_stage": 0, - # "overlap_communication": False, - # "precision": "fp16", - # }, # [dp(4)] + [moe_dp(4)] - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "ep_size": 2, - # "zero_stage": 1, - # "overlap_communication": False, - # "precision": "fp32", - # }, # [dp(2) + pp(2)] + [moe_pp(2)] - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "ep_size": 2, - # "zero_stage": 1, - # "overlap_communication": False, - # "precision": "fp32", - # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass - { # Ulysess + Flash attention - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "ep_size": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "zero_stage": 1, - "overlap_communication": False, - "precision": "fp16", - "initial_scale": 1, - "find_unused_parameters": True, - }, - # { - # "tp_size": 1, - # "pp_size": 1, - # "ep_size": 2, - # "zero_stage": 0, - # "overlap_communication": False, - # "precision": "fp32", - # }, # [dp(4)] + [ep(2) + moe_tp(2)] - # { - # "tp_size": 1, - # "pp_size": 1, - # "ep_size": 4, - # "overlap_communication": False, - # "zero_stage": 0, - # "precision": "fp32" - # }, # full dp for non-moe and full ep for moe + (2, 1, 1, 4, 1), + (2, 1, 2, 1, 1), + (2, 1, 2, 2, 1), + (2, 1, 1, 2, 1), + (2, 1, 1, 1, 2), + (2, 1, 4, 1, 1), + (4, 1, 1, 1, 1), + (4, 1, 1, 2, 1), + (4, 1, 2, 1, 1), ], ) -def run_mixtral_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral") +def run_zero_with_original_model(config: Tuple[int, ...]): + ep_size, stage, pp_size, tp_size, sp_size = config + world_size = dist.get_world_size() + rank = dist.get_rank() + dtype, precision = 
torch.float16, "fp16" + torch.cuda.set_device(dist.get_rank()) - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + plugin = MoeHybridParallelPlugin( + pp_size=pp_size, + num_microbatches=pp_size, + tp_size=tp_size, + sp_size=sp_size, + ep_size=ep_size, + moe_tp_size=tp_size, + zero_stage=stage, + enable_sequence_parallelism=sp_size > 1, + sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, + overlap_communication=False, + initial_scale=1, + precision=precision, + find_unused_parameters=True, + ) + dp_size = plugin.dp_size - clear_layout_converter() - Randomizer.reset_index() - torch.cuda.empty_cache() + booster = Booster(plugin=plugin) + + # init model with the same seed + seed_all(10086) + + assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" + config = MixtralConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, + attn_implementation="flash_attention_2", + ) + + torch_model = MixtralModel(config).to(dtype).cuda() + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + + parallel_model = deepcopy(torch_model) + parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1) + parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer) + + # create different input along dp axis + seed_all(1453 + rank) + + torch_model.train() + parallel_model.train() + for _ in range(2): + # gen random input + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check + + dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input + dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input + + # run the model with hybrid parallel + if booster.plugin.stage_manager is not None: + # for test with pp + data_iter = iter([{"inputs_embeds": input_embeddings}]) + sharded_output = booster.execute_pipeline( + data_iter, + parallel_model, + lambda x, y: x.last_hidden_state.mean(), + parallel_optimizer, + return_loss=True, + return_outputs=True, + ) + if booster.plugin.stage_manager.is_last_stage(): + parallel_output = sharded_output["loss"] + else: + parallel_output = torch.tensor(12345.0, device="cuda") + + # broadcast along pp axis + dist.broadcast( + parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group + ) + else: + # for test without pp + parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean() + parallel_optimizer.backward(parallel_output) + parallel_optimizer.step() + parallel_optimizer.zero_grad() + dist.all_reduce(parallel_output, group=plugin.dp_group) + + # =================================================================================== + # run normal model with all dp(different) inputs + all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)] + dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + torch_output_sum = 0 + for input_data_ in all_inputs: + 
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() + torch_output.backward() + torch_output_sum += torch_output.detach() + # avg dp grads follows zero optimizer + for p in torch_model.parameters(): + if p.grad is not None: + p.grad /= dp_size + torch_optimizer.step() + torch_optimizer.zero_grad() + + loose_close(parallel_output, torch_output_sum, dtype=dtype) + + # use checkpoint to load sharded zero model + model_dir = "./test_mixtral" + if rank == world_size - 1: + os.makedirs(model_dir, exist_ok=True) + + dist.barrier() + booster.save_model(parallel_model, model_dir, shard=True) + dist.barrier() + + saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype) + check_model_equal(torch_model, saved_model) + dist.barrier() + + if rank == world_size - 1: + shutil.rmtree(model_dir) + + print(f"rank {dist.get_rank()} test passed") -def check_mixtral(rank, world_size, port): - disable_existing_loggers() +def run_dist(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_mixtral_test() + run_zero_with_original_model() @pytest.mark.dist +@pytest.mark.parametrize("world_size", [8]) @rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_mixtral(): - spawn(check_mixtral, 4) +def test_mistral(world_size): + spawn(run_dist, world_size) if __name__ == "__main__": - test_mixtral() + test_mistral(world_size=8)
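The group rewiring in MoeHybridParallelPlugin (keeping moe_dp_group equal to the plain dp_group and introducing dp_sp_group over the combined dp and sp axes when sequence_parallelism_mode is "all_to_all", then handing dp_sp_group to HybridParallelModule and the ZeRO optimizer) can be pictured with a small rank-mesh calculation. The sketch below is an illustration only: it assumes a (dp, pp, tp, sp) axis order and the axis_groups helper is hypothetical, not the real ProcessGroupMesh API.

import itertools
import numpy as np

def axis_groups(shape, axes):
    """Rank groups obtained by collapsing the given axes of a rank mesh."""
    mesh = np.arange(int(np.prod(shape))).reshape(shape)
    kept = [a for a in range(len(shape)) if a not in axes]
    groups = []
    for idx in itertools.product(*(range(shape[a]) for a in kept)):
        index = [slice(None)] * len(shape)
        for a, i in zip(kept, idx):
            index[a] = i
        groups.append(mesh[tuple(index)].reshape(-1).tolist())
    return groups

if __name__ == "__main__":
    # 8 ranks arranged as dp=2, pp=1, tp=2, sp=2 (the axis order is an assumption).
    shape = (2, 1, 2, 2)
    DP, PP, TP, SP = 0, 1, 2, 3
    print("dp groups    ", axis_groups(shape, {DP}))      # dp-only groups (what moe_dp_group reuses)
    print("sp groups    ", axis_groups(shape, {SP}))      # reused as ep groups under all_to_all SP
    print("dp_sp groups ", axis_groups(shape, {DP, SP}))  # the new dp_sp_group

Under all_to_all sequence parallelism the ep_group reuses sp_group, while the non-MoE parameters stay replicated across both the dp and sp ranks, so their gradients have to be reduced over the combined dp x sp group; that is the group the patch now passes as dp_group / dp_process_group to HybridParallelModule and the low-level ZeRO optimizer.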