[test] add mixtral transformer test

moe_sp
hxwang 2024-07-02 09:08:41 +00:00
parent 229db4bc16
commit 6a9164a477
6 changed files with 281 additions and 30 deletions

View File

@@ -4,8 +4,6 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.mixtral.modeling_mixtral import (
@@ -23,30 +21,34 @@ from colossalai.shardformer.shard.utils import set_tensors_to_none
 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config):
-        self.moe_info = None
+    def __init__(self, config, ep_group):
         super().__init__(config)
+        self.setup_ep(ep_group)

     def setup_ep(self, ep_group: ProcessGroup):
         ep_group = ep_group
         self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
         self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
-        assert self.num_experts % self.ep_size == 0
         self.ep_group = ep_group
+        if self.num_experts % self.ep_size != 0:
+            raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
         self.num_experts_per_ep = self.num_experts // self.ep_size
         self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
         held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
             p.ep_group = ep_group

     @staticmethod
-    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+    def from_native_module(
+        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
+    ) -> "EPMixtralSparseMoeBlock":
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        # if "ep_group" in kwargs:
-        assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
-        module.setup_ep(kwargs["ep_group"])
+        module.setup_ep(ep_group)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

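For orientation, the partitioning that setup_ep performs reduces to simple slice arithmetic: each expert-parallel rank keeps a contiguous block of num_experts // ep_size experts, and the parameters of the remaining experts are freed via set_tensors_to_none. The sketch below is a standalone illustration of that arithmetic only; held_expert_indices is a hypothetical helper, with no process groups or ColossalAI code involved.

def held_expert_indices(num_experts: int, ep_rank: int, ep_size: int) -> range:
    # Same divisibility check and slice arithmetic as setup_ep above.
    if num_experts % ep_size != 0:
        raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
    num_experts_per_ep = num_experts // ep_size
    expert_start_idx = ep_rank * num_experts_per_ep
    return range(expert_start_idx, expert_start_idx + num_experts_per_ep)


# With Mixtral's default 8 local experts and ep_size=4 (as in the test configs below),
# rank 0 keeps experts [0, 1], rank 1 keeps [2, 3], and so on.
for rank in range(4):
    print(rank, list(held_expert_indices(8, rank, 4)))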
View File

@@ -3,28 +4,16 @@ from .bert import *
 from .blip2 import *
 from .bloom import *
 from .chatglm2 import *
+from .command import *
 from .falcon import *
 from .gpt import *
 from .gptj import *
 from .llama import *
+from .mistral import *
+from .mixtral import *
 from .opt import *
+from .qwen2 import *
 from .sam import *
 from .t5 import *
 from .vit import *
 from .whisper import *
-try:
-    from .mistral import *
-except ImportError:
-    print("This version of transformers doesn't support mistral.")
-try:
-    from .qwen2 import *
-except ImportError:
-    print("This version of transformers doesn't support qwen2.")
-try:
-    from .command import *
-except ImportError:
-    print("This version of transformers doesn't support Command-R.")

View File

@@ -0,0 +1,82 @@
+# modified from tests/kit/model_zoo/transformers/mistral.py
+import torch
+import transformers
+from transformers import MixtralConfig
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence Mixtral
+# ===============================
+
+
+def data_gen():
+    # Generated from the following code snippet:
+    #
+    # from transformers import AutoModelForCausalLM, AutoTokenizer
+    # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+    # input = 'My favourite condiment is vinegar' (the last two words are repeated to satisfy the length requirement)
+    # tokenized_input = tokenizer([input], return_tensors="pt")
+    # input_ids = tokenized_input['input_ids']
+    # attention_mask = tokenized_input['attention_mask']
+    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+    # LM data gen
+    # for causal LM the labels are the input tokens themselves; since there is no padding, reuse `input_ids` as `labels`
+    data = data_gen()
+    data["labels"] = data["input_ids"].clone()
+    return data
+
+
+def data_gen_for_sequence_classification():
+    # sequence classification data gen
+    data = data_gen()
+    data["labels"] = torch.tensor([1], dtype=torch.int64)
+    return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss functions
+loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(
+    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+)
+loss_fn = lambda x: x.loss
+loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+config = MixtralConfig(
+    hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
+)
+
+if hasattr(config, "pad_token_id"):
+    config.pad_token_id = config.eos_token_id
+
+model_zoo.register(
+    name="transformers_mixtral",
+    model_fn=lambda: transformers.MixtralModel(config),
+    data_gen_fn=data_gen,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn_for_mixtral_model,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_mixtral_for_casual_lm",
+    model_fn=lambda: transformers.MixtralForCausalLM(config),
+    data_gen_fn=data_gen_for_lm,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_mixtral_for_sequence_classification",
+    model_fn=lambda: transformers.MixtralForSequenceClassification(config),
+    data_gen_fn=data_gen_for_sequence_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn_for_seq_classification,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)

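To make concrete what these registrations exercise, the following standalone sketch builds the same tiny Mixtral configuration and runs the causal-LM variant forward once with plain transformers and torch, outside the model_zoo registry and booster; it assumes a transformers release that ships the Mixtral classes.

import torch
from transformers import MixtralConfig, MixtralForCausalLM

# Same tiny configuration as the registration above.
config = MixtralConfig(
    hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
)
config.pad_token_id = config.eos_token_id

model = MixtralForCausalLM(config)

# Same fixed batch as data_gen_for_lm(): the labels are just the input ids.
input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.clone())
print(outputs.loss)  # the scalar that loss_fn = lambda x: x.loss extracts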
View File

@@ -10,8 +10,6 @@ from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_moe_epsize_param_dict
-# from colossalai.shardformer.layer.moe import SparseMLP
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group

View File

@@ -1,6 +1,6 @@
 import copy
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type

 import torch
 import torch.distributed as dist
@@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
 def build_model_from_hybrid_plugin(
-    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+    model_fn: Callable,
+    loss_fn: Callable,
+    test_config: Dict[str, Any],
+    optim_class=Adam,
+    sharded_optim_class=Adam,
+    pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
 ):
     use_lazy_init = False
     if "use_lazy_init" in test_config:
@@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
     else:
         org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
         sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
     criterion = loss_fn
-    plugin = HybridParallelPlugin(**test_config)
+    plugin = pluggin_cls(**test_config)
     booster = Booster(plugin=plugin)

     sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)

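The substantive change in this helper is that the plugin class is now injected by the caller (pluggin_cls) instead of being hard-coded to HybridParallelPlugin, which is what lets the Mixtral test below build with MoeHybridParallelPlugin. A rough sketch of the same injection pattern with hypothetical stand-in classes (BasePlugin, MoePlugin, and make_plugin are illustrative only, not ColossalAI APIs):

from typing import Any, Dict, Type

class BasePlugin:
    """Stand-in for a hybrid-parallel plugin: stores its constructor kwargs."""
    def __init__(self, **config: Any) -> None:
        self.config = config

class MoePlugin(BasePlugin):
    """Stand-in for an MoE variant: adds an expert-parallel size."""
    def __init__(self, ep_size: int = 1, **config: Any) -> None:
        super().__init__(**config)
        self.ep_size = ep_size

def make_plugin(test_config: Dict[str, Any], pluggin_cls: Type[BasePlugin] = BasePlugin) -> BasePlugin:
    # Mirrors `plugin = pluggin_cls(**test_config)` in build_model_from_hybrid_plugin.
    return pluggin_cls(**test_config)

plugin = make_plugin({"ep_size": 4, "precision": "fp16"}, pluggin_cls=MoePlugin)
print(type(plugin).__name__, plugin.ep_size)  # MoePlugin 4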
View File

@@ -0,0 +1,175 @@
+# modified from test_shard_mistral.py
+import os
+
+import pytest
+import torch
+
+import colossalai
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_all_grad_tensors,
+    check_loss,
+    check_weight,
+    get_grad_tensors_for_check,
+    run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
+)
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
+    )
+
+    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+    )
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # unwrap model
+    mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
+    shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
+
+    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
+    col_layer_for_check = ["layers[0].self_attn.o_proj"]
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        if test_config["precision"] == "fp32":
+            atol, rtol = 5e-5, 1e-4
+        else:
+            atol, rtol = 5e-3, 5e-3
+        row_layer_grads = get_grad_tensors_for_check(
+            mixtral_model,
+            shard_mixtral_model,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
+        )
+        col_layer_grads = get_grad_tensors_for_check(
+            mixtral_model,
+            shard_mixtral_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
+    if stage_manager is None or stage_manager.is_first_stage():
+        if test_config["precision"] == "fp32":
+            atol, rtol = 2e-4, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        check_weight(
+            mixtral_model,
+            shard_mixtral_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 4,
+            "num_microbatches": 2,
+            "zero_stage": 0,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 4,
+            "num_microbatches": 2,
+            "zero_stage": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 4,
+            "num_microbatches": 2,
+            "zero_stage": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+    ],
+)
+def run_mixtral_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+def check_mixtral(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_mixtral_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_mixtral():
+    spawn(check_mixtral, 4)
+
+
+if __name__ == "__main__":
+    test_mixtral()
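The check_loss, check_weight, and check_all_grad_tensors helpers used above ultimately come down to elementwise comparisons under the atol/rtol pairs selected per precision. As a standalone illustration of that style of tolerance check, here is a minimal sketch using torch.testing.assert_close directly; compare is a hypothetical helper, not one of the ColossalAI test utilities.

import torch

def compare(org: torch.Tensor, sharded: torch.Tensor, precision: str) -> None:
    # Same tolerance scheme as the test above: tight bounds for fp32, looser ones for fp16.
    atol, rtol = (5e-5, 1e-4) if precision == "fp32" else (5e-3, 5e-3)
    torch.testing.assert_close(sharded, org, atol=atol, rtol=rtol)

grad = torch.randn(64, 64)
compare(grad, grad + 1e-6, precision="fp16")     # passes: difference is within the fp16 tolerance
try:
    compare(grad, grad + 1.0, precision="fp16")  # fails: difference is well outside the tolerance
except AssertionError as e:
    print("mismatch detected:", type(e).__name__)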