mirror of https://github.com/hpcaitech/ColossalAI
[moe] finalize test (no pp)
parent 2cddeac717
commit 7077d38d5a
@@ -109,6 +109,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         super().__init__(*args, **kwargs)
 
+        if ep_size <= 1:
+            raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
+
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
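The new guard means this plugin is only meant for runs that actually shard experts. A minimal sketch of how a caller might branch on that, assuming both plugins are exported from colossalai.booster.plugin; the helper name and keyword passthrough are illustrative and not part of the commit:

from colossalai.booster.plugin import HybridParallelPlugin, MoeHybridParallelPlugin

def build_plugin(ep_size: int, **plugin_kwargs):
    # Mirrors the check added above: MoeHybridParallelPlugin now rejects ep_size <= 1,
    # so fall back to the plain HybridParallelPlugin when expert parallelism is off.
    if ep_size <= 1:
        return HybridParallelPlugin(**plugin_kwargs)
    return MoeHybridParallelPlugin(ep_size=ep_size, **plugin_kwargs)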
@@ -128,12 +131,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.ddp_config["find_unused_parameters"] = True
 
             if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                # TODO it might make sense to support non-moe with tp on but moe with tp off
                 raise ValueError(
-                    f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
+                    f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
                 )
 
-        # set ep_group after super().__init__()
-        # TODO do it in a better way
+        # set param group in shard config
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
@@ -149,9 +152,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             # when sequence parallelism is enabled, ep_group reuses sp_group
             if self.ep_size != self.sp_size:
                 raise ValueError(
-                    f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
                 )
 
+            # since we are reusing sp_group, moe_dp_group will be derived as dp_group
             self.moe_dp_size = self.dp_size
             self.moe_dp_group = self.dp_group  # NOTE: sequence of value assignment matters
             self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -165,7 +169,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         else:
             self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
 
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+        if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
             raise ValueError(
                 f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
             )
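As a quick sanity check of the revised divisibility test, here is the arithmetic for a hypothetical 8-rank run; the sizes are chosen only for illustration:

world_size = 8
pp_size, ep_size, moe_tp_size = 1, 2, 2
# moe_dp_size is derived exactly as in the else-branch above
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)  # 8 // 4 == 2
# the revised check no longer multiplies in sp_size
assert moe_dp_size * pp_size * ep_size * moe_tp_size == world_size  # 2 * 1 * 2 * 2 == 8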
@@ -214,8 +218,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.moe_tp_group = group
 
         if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
-            # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
-            # this assertion implies that dp_size == moe_dp_size * ep_size
+            # NOTE: different tp settings between moe and non moe param are complex to handle
+            # we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
             raise NotImplementedError(
                 f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
            )
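Putting the plugin-side constraints together before moving on to the updated test: ep_size must be greater than 1, tp_size must equal moe_tp_size, and with DDP ruled out for mismatched dp/moe_dp groups a zero_stage above 0 is the supported path. A hedged sketch of a configuration that satisfies the new checks, mirroring the keyword arguments used in the test below; it assumes the import path shown, that the distributed environment is already initialized (as the test does via spawn/run_dist), and that pp_size * moe_dp_size * ep_size * moe_tp_size matches world_size:

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # import path assumed

# Illustrative sizes only: ep_size > 1, tp_size == moe_tp_size, and zero_stage > 0 so that
# gradients are reduced by ZeRO rather than DDP (which would require dp_group == moe_dp_group).
plugin = MoeHybridParallelPlugin(
    pp_size=1,
    tp_size=2,
    ep_size=2,
    moe_tp_size=2,
    zero_stage=1,
    precision="fp16",
)
booster = Booster(plugin=plugin)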
@@ -18,28 +18,34 @@ from tests.test_moe.moe_utils import loose_close
 from tests.test_moe.test_moe_checkpoint import check_model_equal
 
 NUM_BATCH = 4
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
-    stage, ep_size, tp_size = config
-    dtype, precision = torch.float16, "fp16"
+    ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
+    print(config)
     rank = torch.distributed.get_rank()
+    dtype, precision = torch.float16, "fp16"
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
-        pp_size=1,
+        pp_size=pp_size,
+        num_microbatches=pp_size,
         tp_size=tp_size,
-        moe_tp_size=tp_size,
+        sp_size=sp_size,
         ep_size=ep_size,
+        moe_tp_size=tp_size,
         zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
+        find_unused_parameters=True,
     )
     booster = Booster(plugin=plugin)
 
@@ -53,6 +59,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         num_key_value_heads=NUM_HEADS,
         num_local_experts=NUM_EXPERTS,
         num_experts_per_tok=TOP_K,
+        attn_implementation="flash_attention_2",
     )
 
     torch_model = MixtralModel(config).to(dtype).cuda()
@@ -72,7 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     input_data = torch.rand(
         NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
     ).cuda()
-    dist.all_reduce(input_data, group=plugin.tp_group)  # tp requires duplicate input
+    dist.all_reduce(input_data, group=plugin.tp_group)  # tp group requires duplicate input
+    dist.all_reduce(input_data, group=plugin.sp_group)  # sp group requires duplicate input
 
     zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
     zero_optimizer.backward(zero_output)
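The two all_reduce calls work because, after an all_reduce, every rank in the group holds the same element-wise sum, so inputs that started out different become identical within the tp and sp groups. A standalone illustration of that idiom with plain torch.distributed (CPU/gloo, two local processes; the port, tensor shape, and worker name are arbitrary):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    x = torch.rand(2, 3)  # different on every rank
    dist.all_reduce(x)    # now identical (the element-wise sum) on every rank
    gathered = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x)
    assert all(torch.equal(t, gathered[0]) for t in gathered)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)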
@@ -124,11 +133,11 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("world_size", [8])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_mistral(world_size=4)
+    test_mistral(world_size=8)