From 6a9164a477591c9c8eaecbaef64fa111cee3a49c Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 2 Jul 2024 09:08:41 +0000 Subject: [PATCH] [test] add mixtral transformer test --- colossalai/shardformer/modeling/mixtral.py | 20 +- tests/kit/model_zoo/transformers/__init__.py | 20 +- tests/kit/model_zoo/transformers/mixtral.py | 82 ++++++++ tests/test_moe/moe_utils.py | 2 - tests/test_shardformer/test_model/_utils.py | 12 +- .../test_model/test_shard_mixtral.py | 175 ++++++++++++++++++ 6 files changed, 281 insertions(+), 30 deletions(-) create mode 100644 tests/kit/model_zoo/transformers/mixtral.py create mode 100644 tests/test_shardformer/test_model/test_shard_mixtral.py diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 2fbc34302..334bd13fc 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -4,8 +4,6 @@ import torch import torch.distributed as dist import torch.nn.functional as F from torch.distributed import ProcessGroup - -# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo from torch.nn import CrossEntropyLoss from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.mixtral.modeling_mixtral import ( @@ -23,30 +21,34 @@ from colossalai.shardformer.shard.utils import set_tensors_to_none class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): - def __init__(self, config): - self.moe_info = None + def __init__(self, config, ep_group): super().__init__(config) + self.setup_ep(ep_group) def setup_ep(self, ep_group: ProcessGroup): ep_group = ep_group self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 - assert self.num_experts % self.ep_size == 0 self.ep_group = ep_group + + if self.num_experts % self.ep_size != 0: + raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") + self.num_experts_per_ep = self.num_experts // self.ep_size self.expert_start_idx = self.ep_rank * self.num_experts_per_ep held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) for p in self.experts.parameters(): p.ep_group = ep_group @staticmethod - def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + def from_native_module( + module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs + ) -> "EPMixtralSparseMoeBlock": LazyInitContext.materialize(module) module.__class__ = EPMixtralSparseMoeBlock - # if "ep_group" in kwargs: - assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!" 
- module.setup_ep(kwargs["ep_group"]) + module.setup_ep(ep_group) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 05c17f562..ac5184065 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -3,28 +3,16 @@ from .bert import * from .blip2 import * from .bloom import * from .chatglm2 import * +from .command import * from .falcon import * from .gpt import * from .gptj import * from .llama import * +from .mistral import * +from .mixtral import * from .opt import * +from .qwen2 import * from .sam import * from .t5 import * from .vit import * from .whisper import * - -try: - from .mistral import * -except ImportError: - print("This version of transformers doesn't support mistral.") - -try: - from .qwen2 import * -except ImportError: - print("This version of transformers doesn't support qwen2.") - - -try: - from .command import * -except ImportError: - print("This version of transformers doesn't support Command-R.") diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py new file mode 100644 index 000000000..b82a4b939 --- /dev/null +++ b/tests/kit/model_zoo/transformers/mixtral.py @@ -0,0 +1,82 @@ +# modified from tests/kit/model_zoo/transformers/mistral.py +import torch +import transformers +from transformers import MixtralConfig + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Mixtral +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoModelForCausalLM, AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1") + # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement) + # tokenized_input = tokenizer([input], return_tensors="pt") + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss +loss_fn_for_seq_classification = lambda output: output.logits.mean() + +config = MixtralConfig( + hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258 +) + +if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + +model_zoo.register( + name="transformers_mixtral", + model_fn=lambda: transformers.MixtralModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mixtral_model, + 
model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_mixtral_for_casual_lm", + model_fn=lambda: transformers.MixtralForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_mixtral_for_sequence_classification", + model_fn=lambda: transformers.MixtralForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 131932dcb..ba6a0e8a9 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -10,8 +10,6 @@ from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict - -# from colossalai.shardformer.layer.moe import SparseMLP from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 1ffcc541a..190fee129 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,6 +1,6 @@ import copy from contextlib import nullcontext -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Type import torch import torch.distributed as dist @@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""): def build_model_from_hybrid_plugin( - model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam + model_fn: Callable, + loss_fn: Callable, + test_config: Dict[str, Any], + optim_class=Adam, + sharded_optim_class=Adam, + pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin, ): use_lazy_init = False if "use_lazy_init" in test_config: @@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin( else: org_optimizer = optim_class(org_model.parameters(), lr=1e-3) sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn - plugin = HybridParallelPlugin(**test_config) + plugin = pluggin_cls(**test_config) booster = Booster(plugin=plugin) sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py new file mode 100644 index 000000000..bf2d2bb1b --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -0,0 +1,175 @@ +# modified from test_shard_mistral.py +import os + +import pytest +import torch + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + 
build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + mixtral_model = unwrap_model(org_model, "MixtralModel", "model") + shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 5e-5, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + mixtral_model, + shard_mixtral_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + col_layer_grads = get_grad_tensors_for_check( + mixtral_model, + shard_mixtral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 2e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + mixtral_model, + shard_mixtral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 1, + "ep_size": 4, + "num_microbatches": 2, + "zero_stage": 0, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "ep_size": 4, + "num_microbatches": 2, + "zero_stage": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "ep_size": 4, + "num_microbatches": 2, + "zero_stage": 2, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_mixtral_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + 
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_mixtral(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_mixtral_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_mixtral(): + spawn(check_mixtral, 4) + + +if __name__ == "__main__": + test_mixtral()
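
Usage sketch for the reworked from_native_module API: with this patch, EPMixtralSparseMoeBlock takes the expert-parallel process group as an explicit argument instead of reading it out of kwargs and asserting on its presence. The snippet below is a minimal, hedged illustration of the new call shape only; the helper name replace_moe_block and the way the process group is obtained are assumptions for the example, not code from this patch.

    # Minimal sketch of the new calling convention (assumed wiring, not part of the patch).
    from torch.distributed import ProcessGroup
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock


    def replace_moe_block(block: MixtralSparseMoeBlock, ep_group: ProcessGroup) -> EPMixtralSparseMoeBlock:
        # ep_group is now a required parameter of from_native_module; omitting it
        # raises a TypeError at the call site instead of tripping the old kwargs assert.
        return EPMixtralSparseMoeBlock.from_native_module(block, ep_group=ep_group)

On a machine with at least four GPUs, the new test can be launched by running the file directly (the __main__ guard calls test_mixtral(), and spawn(check_mixtral, 4) starts the four NCCL ranks itself).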