mirror of https://github.com/hpcaitech/ColossalAI
[test] add mixtral transformer test
parent
229db4bc16
commit
6a9164a477
|
@ -4,8 +4,6 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
|
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
from transformers.models.mixtral.modeling_mixtral import (
|
from transformers.models.mixtral.modeling_mixtral import (
|
||||||
|
@ -23,30 +21,34 @@ from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||||
|
|
||||||
|
|
||||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||||
def __init__(self, config):
|
def __init__(self, config, ep_group):
|
||||||
self.moe_info = None
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
self.setup_ep(ep_group)
|
||||||
|
|
||||||
def setup_ep(self, ep_group: ProcessGroup):
|
def setup_ep(self, ep_group: ProcessGroup):
|
||||||
ep_group = ep_group
|
ep_group = ep_group
|
||||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||||
assert self.num_experts % self.ep_size == 0
|
|
||||||
self.ep_group = ep_group
|
self.ep_group = ep_group
|
||||||
|
|
||||||
|
if self.num_experts % self.ep_size != 0:
|
||||||
|
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
||||||
|
|
||||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||||
|
|
||||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||||
for p in self.experts.parameters():
|
for p in self.experts.parameters():
|
||||||
p.ep_group = ep_group
|
p.ep_group = ep_group
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
|
def from_native_module(
|
||||||
|
module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
|
||||||
|
) -> "EPMixtralSparseMoeBlock":
|
||||||
LazyInitContext.materialize(module)
|
LazyInitContext.materialize(module)
|
||||||
module.__class__ = EPMixtralSparseMoeBlock
|
module.__class__ = EPMixtralSparseMoeBlock
|
||||||
# if "ep_group" in kwargs:
|
module.setup_ep(ep_group)
|
||||||
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
|
|
||||||
module.setup_ep(kwargs["ep_group"])
|
|
||||||
return module
|
return module
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
|
@ -3,28 +3,16 @@ from .bert import *
|
||||||
from .blip2 import *
|
from .blip2 import *
|
||||||
from .bloom import *
|
from .bloom import *
|
||||||
from .chatglm2 import *
|
from .chatglm2 import *
|
||||||
|
from .command import *
|
||||||
from .falcon import *
|
from .falcon import *
|
||||||
from .gpt import *
|
from .gpt import *
|
||||||
from .gptj import *
|
from .gptj import *
|
||||||
from .llama import *
|
from .llama import *
|
||||||
|
from .mistral import *
|
||||||
|
from .mixtral import *
|
||||||
from .opt import *
|
from .opt import *
|
||||||
|
from .qwen2 import *
|
||||||
from .sam import *
|
from .sam import *
|
||||||
from .t5 import *
|
from .t5 import *
|
||||||
from .vit import *
|
from .vit import *
|
||||||
from .whisper import *
|
from .whisper import *
|
||||||
|
|
||||||
try:
|
|
||||||
from .mistral import *
|
|
||||||
except ImportError:
|
|
||||||
print("This version of transformers doesn't support mistral.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .qwen2 import *
|
|
||||||
except ImportError:
|
|
||||||
print("This version of transformers doesn't support qwen2.")
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .command import *
|
|
||||||
except ImportError:
|
|
||||||
print("This version of transformers doesn't support Command-R.")
|
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
# modified from tests/kit/model_zoo/transformers/mistral.py
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import MixtralConfig
|
||||||
|
|
||||||
|
from ..registry import ModelAttribute, model_zoo
|
||||||
|
|
||||||
|
# ===============================
|
||||||
|
# Register single-sentence Mixtral
|
||||||
|
# ===============================
|
||||||
|
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
# Generated from following code snippet
|
||||||
|
#
|
||||||
|
# from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
# tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
|
||||||
|
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
|
||||||
|
# tokenized_input = tokenizer([input], return_tensors="pt")
|
||||||
|
# input_ids = tokenized_input['input_ids']
|
||||||
|
# attention_mask = tokenized_input['attention_mask']
|
||||||
|
input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
|
||||||
|
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||||
|
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def data_gen_for_lm():
|
||||||
|
# LM data gen
|
||||||
|
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||||
|
data = data_gen()
|
||||||
|
data["labels"] = data["input_ids"].clone()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def data_gen_for_sequence_classification():
|
||||||
|
# sequence classification data gen
|
||||||
|
data = data_gen()
|
||||||
|
data["labels"] = torch.tensor([1], dtype=torch.int64)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# define output transform function
|
||||||
|
output_transform_fn = lambda x: x
|
||||||
|
|
||||||
|
# define loss function
|
||||||
|
loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(
|
||||||
|
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||||
|
)
|
||||||
|
loss_fn = lambda x: x.loss
|
||||||
|
loss_fn_for_seq_classification = lambda output: output.logits.mean()
|
||||||
|
|
||||||
|
config = MixtralConfig(
|
||||||
|
hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(config, "pad_token_id"):
|
||||||
|
config.pad_token_id = config.eos_token_id
|
||||||
|
|
||||||
|
model_zoo.register(
|
||||||
|
name="transformers_mixtral",
|
||||||
|
model_fn=lambda: transformers.MixtralModel(config),
|
||||||
|
data_gen_fn=data_gen,
|
||||||
|
output_transform_fn=output_transform_fn,
|
||||||
|
loss_fn=loss_fn_for_mixtral_model,
|
||||||
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
|
)
|
||||||
|
model_zoo.register(
|
||||||
|
name="transformers_mixtral_for_casual_lm",
|
||||||
|
model_fn=lambda: transformers.MixtralForCausalLM(config),
|
||||||
|
data_gen_fn=data_gen_for_lm,
|
||||||
|
output_transform_fn=output_transform_fn,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
|
)
|
||||||
|
model_zoo.register(
|
||||||
|
name="transformers_mixtral_for_sequence_classification",
|
||||||
|
model_fn=lambda: transformers.MixtralForSequenceClassification(config),
|
||||||
|
data_gen_fn=data_gen_for_sequence_classification,
|
||||||
|
output_transform_fn=output_transform_fn,
|
||||||
|
loss_fn=loss_fn_for_seq_classification,
|
||||||
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
|
)
|
|
@ -10,8 +10,6 @@ from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
|
||||||
from colossalai.legacy.registry import GRADIENT_HANDLER
|
from colossalai.legacy.registry import GRADIENT_HANDLER
|
||||||
from colossalai.moe.manager import MOE_MANAGER
|
from colossalai.moe.manager import MOE_MANAGER
|
||||||
from colossalai.moe.utils import get_moe_epsize_param_dict
|
from colossalai.moe.utils import get_moe_epsize_param_dict
|
||||||
|
|
||||||
# from colossalai.shardformer.layer.moe import SparseMLP
|
|
||||||
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
|
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
|
||||||
|
|
||||||
|
|
||||||
def build_model_from_hybrid_plugin(
|
def build_model_from_hybrid_plugin(
|
||||||
model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
|
model_fn: Callable,
|
||||||
|
loss_fn: Callable,
|
||||||
|
test_config: Dict[str, Any],
|
||||||
|
optim_class=Adam,
|
||||||
|
sharded_optim_class=Adam,
|
||||||
|
pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
|
||||||
):
|
):
|
||||||
use_lazy_init = False
|
use_lazy_init = False
|
||||||
if "use_lazy_init" in test_config:
|
if "use_lazy_init" in test_config:
|
||||||
|
@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
|
||||||
else:
|
else:
|
||||||
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
|
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
|
||||||
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
|
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
criterion = loss_fn
|
criterion = loss_fn
|
||||||
|
|
||||||
plugin = HybridParallelPlugin(**test_config)
|
plugin = pluggin_cls(**test_config)
|
||||||
booster = Booster(plugin=plugin)
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
||||||
|
|
|
@ -0,0 +1,175 @@
|
||||||
|
# modified from test_shard_mistral.py
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_shardformer.test_model._utils import (
|
||||||
|
build_model_from_hybrid_plugin,
|
||||||
|
check_all_grad_tensors,
|
||||||
|
check_loss,
|
||||||
|
check_weight,
|
||||||
|
get_grad_tensors_for_check,
|
||||||
|
run_forward_backward_with_hybrid_plugin,
|
||||||
|
unwrap_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||||
|
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||||
|
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
|
||||||
|
)
|
||||||
|
|
||||||
|
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||||
|
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||||
|
)
|
||||||
|
|
||||||
|
stage_manager = booster.plugin.stage_manager
|
||||||
|
tp_group = booster.plugin.tp_group
|
||||||
|
|
||||||
|
# unwrap model
|
||||||
|
mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
|
||||||
|
shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
|
||||||
|
|
||||||
|
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
|
||||||
|
col_layer_for_check = ["layers[0].self_attn.o_proj"]
|
||||||
|
|
||||||
|
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||||
|
grads_to_check = {}
|
||||||
|
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||||
|
if test_config["precision"] == "fp32":
|
||||||
|
atol, rtol = 5e-5, 1e-4
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
row_layer_grads = get_grad_tensors_for_check(
|
||||||
|
mixtral_model,
|
||||||
|
shard_mixtral_model,
|
||||||
|
row_layer_for_check,
|
||||||
|
tp_group,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
dim=0,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
col_layer_grads = get_grad_tensors_for_check(
|
||||||
|
mixtral_model,
|
||||||
|
shard_mixtral_model,
|
||||||
|
col_layer_for_check,
|
||||||
|
tp_group,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
dim=1,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
grads_to_check.update(col_layer_grads)
|
||||||
|
grads_to_check.update(row_layer_grads)
|
||||||
|
|
||||||
|
# optimizer executes step
|
||||||
|
org_optimizer.step()
|
||||||
|
sharded_optimizer.step()
|
||||||
|
|
||||||
|
# check last hidden state & loss
|
||||||
|
if stage_manager is None or stage_manager.is_last_stage():
|
||||||
|
if test_config["precision"] == "fp32":
|
||||||
|
atol, rtol = 1e-5, 1e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
|
||||||
|
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
# check weights
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
if test_config["precision"] == "fp32":
|
||||||
|
atol, rtol = 2e-4, 1e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
check_weight(
|
||||||
|
mixtral_model,
|
||||||
|
shard_mixtral_model,
|
||||||
|
col_layer_for_check,
|
||||||
|
tp_group,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
dim=1,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check grads
|
||||||
|
check_all_grad_tensors(grads_to_check)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"test_config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 1,
|
||||||
|
"ep_size": 4,
|
||||||
|
"num_microbatches": 2,
|
||||||
|
"zero_stage": 0,
|
||||||
|
"enable_all_optimization": True,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 1,
|
||||||
|
"ep_size": 4,
|
||||||
|
"num_microbatches": 2,
|
||||||
|
"zero_stage": 1,
|
||||||
|
"enable_all_optimization": True,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 1,
|
||||||
|
"ep_size": 4,
|
||||||
|
"num_microbatches": 2,
|
||||||
|
"zero_stage": 2,
|
||||||
|
"enable_all_optimization": True,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_mixtral_test(test_config):
|
||||||
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral")
|
||||||
|
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
|
||||||
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def check_mixtral(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
|
run_mixtral_test()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def test_mixtral():
|
||||||
|
spawn(check_mixtral, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_mixtral()
|
Loading…
Reference in New Issue