[moe] full test for deepseek and mixtral (pp + sp to fix)

moe_sp
hxwang 2024-07-19 06:11:11 +00:00
parent 162e2d935c
commit 783aafa327
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8
17 changed files with 430 additions and 517 deletions

View File

@ -1122,6 +1122,10 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.logger.info(
f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
ranks=[0, 1, 2, 3, 4, 5, 6, 7],
)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,

View File

@ -147,9 +147,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
world_size = dist.get_world_size()
if self.enable_sequence_parallelism:
# if sequence parallelism is enabled, we reuse the same group for ep and sp
if self.sequence_parallelism_mode == "all_to_all":
# when sequence parallelism is enabled, ep_group reuses sp_group
# if sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
@ -157,8 +156,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# since we are reusing sp_group, moe_dp_group will be derived as dp_group
self.moe_dp_size = self.dp_size
self.moe_dp_group = self.dp_group # NOTE: sequence of value assignment matters
self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.moe_dp_group = self.dp_group
self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.ep_group = self.sp_group
self.moe_tp_group = self.tp_group
else:
@ -177,6 +176,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
self.dp_sp_group = self.dp_group
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
@ -225,8 +225,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.logger.info(
f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
ranks=[0],
)
@ -254,7 +254,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=self.dp_sp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=self.use_ddp,
@ -302,7 +302,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_group,
dp_process_group=self.dp_sp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_activation
from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size

View File

@ -9,7 +9,7 @@ import torch.nn.functional as F
from colossalai.legacy.moe.load_balance import LoadBalancer
from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.moe.operators import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_activation
from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size

View File

@ -10,7 +10,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig

View File

@ -23,7 +23,13 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
@ -245,6 +251,7 @@ class MixtralPipelineForwards:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
print("input_ids", input_ids.shape)
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
@ -372,16 +379,29 @@ class MixtralPipelineForwards:
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
# always return dict for imediate stage
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}
@staticmethod
def mixtral_for_causal_lm_forward(

View File

@ -34,7 +34,10 @@ class DeepseekPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
raise NotImplementedError(
"Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
@ -136,6 +139,10 @@ class DeepseekPolicy(Policy):
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
if self.shard_config.enable_sequence_parallelism:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "DeepseekModel":
module = self.model

View File

@ -62,6 +62,10 @@ class MixtralPolicy(Policy):
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
self.append_or_create_method_replacement(
description={
"forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
@ -69,19 +73,18 @@ class MixtralPolicy(Policy):
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={
"forward": get_mixtral_flash_attention_model_forward(
self.shard_config,
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=MixtralModel,
)
self.append_or_create_method_replacement(
description={
"forward": get_mixtral_flash_attention_model_forward(
self.shard_config,
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=MixtralModel,
)
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
@ -202,6 +205,10 @@ class MixtralPolicy(Policy):
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
if self.shard_config.enable_sequence_parallelism:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model

View File

@ -100,7 +100,7 @@ class BucketStore(BaseStore):
return self._grad_in_bucket
def get_flatten_grad(self, dtype=None) -> Tensor:
def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data organization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]

View File

@ -303,7 +303,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for bucket_store in self.pg_to_bucket_store.values():
bucket_store.build_grad_in_bucket()
flat_grads = bucket_store.get_flatten_grad(self._dtype)
flat_grads = bucket_store.get_flatten_grad()
flat_grads /= bucket_store.world_size
# ready to add other tensors to bucket

View File

@ -1,133 +0,0 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float16
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
pp_size=1,
tp_size=tp_size,
moe_tp_size=tp_size,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
precision="fp32",
)
booster = Booster(plugin=plugin)
seed_all(10086)
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
config.num_hidden_layers = 2
config.num_attention_heads = NUM_HEADS
config.num_key_value_heads = NUM_HEADS
config.n_routed_experts = NUM_EXPERTS
config.num_experts_per_tok = TOP_K
torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
# create different input
seed_all(1453 + rank)
torch_model.train()
zero_model.train()
for _ in range(2):
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
zero_optimizer.backward(zero_output)
zero_optimizer.step()
zero_optimizer.zero_grad()
dist.all_reduce(zero_output)
all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
dist.all_gather(all_inputs, input_data)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dist.get_world_size()
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(zero_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_deepseek"
if dist.get_rank() == 0:
os.makedirs(model_dir, exist_ok=True)
dist.barrier()
booster.save_model(zero_model, model_dir, shard=True)
dist.barrier()
saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
check_model_equal(torch_model, saved_model)
dist.barrier()
if dist.get_rank() == 0:
shutil.rmtree(model_dir)
print(f"{dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mistral(world_size=4)

View File

@ -1,143 +0,0 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
print(config)
rank = torch.distributed.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
)
booster = Booster(plugin=plugin)
seed_all(10086)
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
attn_implementation="flash_attention_2",
)
torch_model = MixtralModel(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
# create different input
seed_all(1453 + rank)
torch_model.train()
zero_model.train()
for _ in range(2):
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(input_data, group=plugin.tp_group) # tp group requires duplicate input
dist.all_reduce(input_data, group=plugin.sp_group) # sp group requires duplicate input
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
zero_optimizer.backward(zero_output)
zero_optimizer.step()
zero_optimizer.zero_grad()
dist.all_reduce(zero_output)
all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
dist.all_gather(all_inputs, input_data)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dist.get_world_size()
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(zero_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
if dist.get_rank() == 0:
os.makedirs(model_dir, exist_ok=True)
dist.barrier()
booster.save_model(zero_model, model_dir, shard=True)
dist.barrier()
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(torch_model, saved_model)
dist.barrier()
if dist.get_rank() == 0:
shutil.rmtree(model_dir)
print(f"{dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mistral(world_size=8)

View File

@ -4,7 +4,7 @@ import pytest
import torch
from colossalai.accelerator import get_accelerator
from colossalai.moe.operators import MoeCombine, MoeDispatch, moe_cumsum
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
NUM_EXPERTS = 4
BATCH_SIZE = 4

View File

@ -0,0 +1,186 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed
import torch.distributed as dist
from transformers import AutoConfig, AutoModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
# TODO only need to keep one or two cases
@parameterize(
"config",
[
(2, 1, 1, 4, 1),
# (2, 1, 2, 1, 1), # TODO debug deepseek pp
# (2, 1, 2, 2, 1), # TODO debug deepseek pp
(2, 1, 1, 2, 1),
# (2, 1, 1, 1, 2), # TODO support deepseek sp
# (2, 1, 4, 1, 1), # TODO debug deepseek pp
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
# (4, 1, 2, 1, 1), # TODO debug deepseek pp
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
print(config)
plugin = MoeHybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
# init model with the same seed
seed_all(10086)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
config.num_hidden_layers = 2
config.num_attention_heads = NUM_HEADS
config.num_key_value_heads = NUM_HEADS
config.n_routed_experts = NUM_EXPERTS
config.num_experts_per_tok = TOP_K
torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
parallel_model = deepcopy(torch_model)
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
# create different input along dp axis
seed_all(1453 + rank)
torch_model.train()
parallel_model.train()
for _ in range(2):
# gen random input
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
# run the model with hybrid parallel
if booster.plugin.stage_manager is not None:
# for test with pp
data_iter = iter([{"inputs_embeds": input_embeddings}])
sharded_output = booster.execute_pipeline(
data_iter,
parallel_model,
lambda x, y: x[0].mean(),
parallel_optimizer,
return_loss=True,
return_outputs=True,
)
if booster.plugin.stage_manager.is_last_stage():
parallel_output = sharded_output["loss"]
else:
parallel_output = torch.tensor(12345.0, device="cuda")
# broadcast along pp axis
dist.broadcast(
parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
)
else:
# for test without pp
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads follows zero optimizer
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
if rank == world_size - 1:
os.makedirs(model_dir, exist_ok=True)
dist.barrier()
booster.save_model(parallel_model, model_dir, shard=True)
dist.barrier()
saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
check_model_equal(torch_model, saved_model)
dist.barrier()
if rank == world_size - 1:
shutil.rmtree(model_dir)
print(f"rank {dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mistral(world_size=8)

View File

@ -1,229 +1,188 @@
# modified from test_shard_mistral.py
import os
import shutil
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed
import torch.distributed as dist
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
)
org_model = org_model.to(torch.float16)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
# unwrap model
mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Check the grad when using ZeRO-1 and ZeRO-2
if (
# booster.plugin.zero_stage in [1, 2]
booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
rank = dist.get_rank()
name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
for n, p in shard_mixtral_model.named_parameters():
zero_grad = sharded_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 5e-5, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
mixtral_model,
shard_mixtral_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
mixtral_model,
shard_mixtral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check grads
check_all_grad_tensors(grads_to_check)
for n, p in shard_mixtral_model.named_parameters():
assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
for n, p in shard_mixtral_model.named_parameters():
assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
try:
check_weight(
mixtral_model,
shard_mixtral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
except Exception as e:
rank = dist.get_rank()
print(f"{rank=}, Failed config: {test_config}")
raise e
torch.cuda.empty_cache()
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
# TODO only need to keep one or two cases
@parameterize(
"test_config",
"config",
[
# {
# "tp_size": 1,
# "pp_size": 1,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 0,
# "overlap_communication": False,
# "precision": "fp16",
# }, # [dp(4)] + [moe_dp(4)]
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [moe_pp(2)]
# {
# "tp_size": 2,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"ep_size": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp16",
"initial_scale": 1,
"find_unused_parameters": True,
},
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 2,
# "zero_stage": 0,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(4)] + [ep(2) + moe_tp(2)]
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 4,
# "overlap_communication": False,
# "zero_stage": 0,
# "precision": "fp32"
# }, # full dp for non-moe and full ep for moe
(2, 1, 1, 4, 1),
(2, 1, 2, 1, 1),
(2, 1, 2, 2, 1),
(2, 1, 1, 2, 1),
(2, 1, 1, 1, 2),
(2, 1, 4, 1, 1),
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
(4, 1, 2, 1, 1),
],
)
def run_mixtral_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral")
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
plugin = MoeHybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
)
dp_size = plugin.dp_size
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
booster = Booster(plugin=plugin)
# init model with the same seed
seed_all(10086)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
attn_implementation="flash_attention_2",
)
torch_model = MixtralModel(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
parallel_model = deepcopy(torch_model)
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
# create different input along dp axis
seed_all(1453 + rank)
torch_model.train()
parallel_model.train()
for _ in range(2):
# gen random input
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
# run the model with hybrid parallel
if booster.plugin.stage_manager is not None:
# for test with pp
data_iter = iter([{"inputs_embeds": input_embeddings}])
sharded_output = booster.execute_pipeline(
data_iter,
parallel_model,
lambda x, y: x.last_hidden_state.mean(),
parallel_optimizer,
return_loss=True,
return_outputs=True,
)
if booster.plugin.stage_manager.is_last_stage():
parallel_output = sharded_output["loss"]
else:
parallel_output = torch.tensor(12345.0, device="cuda")
# broadcast along pp axis
dist.broadcast(
parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
)
else:
# for test without pp
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads follows zero optimizer
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
if rank == world_size - 1:
os.makedirs(model_dir, exist_ok=True)
dist.barrier()
booster.save_model(parallel_model, model_dir, shard=True)
dist.barrier()
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(torch_model, saved_model)
dist.barrier()
if rank == world_size - 1:
shutil.rmtree(model_dir)
print(f"rank {dist.get_rank()} test passed")
def check_mixtral(rank, world_size, port):
disable_existing_loggers()
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mixtral_test()
run_zero_with_original_model()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mixtral():
spawn(check_mixtral, 4)
def test_mistral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral()
test_mistral(world_size=8)