[moe] test deepseek

colossalchat
hxwang 2024-07-16 10:10:40 +00:00 committed by Hongxin Liu
parent dc583aa576
commit 74eccac0db
10 changed files with 276 additions and 68 deletions

View File

@@ -1,21 +1,27 @@
from typing import List, Optional, Union
from typing import List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
# copied from modeling_deepseek.py
@@ -42,30 +48,60 @@ class AddAuxiliaryLoss(torch.autograd.Function):
class EPDeepseekMoE(nn.Module):
def __init__(self):
super(EPDeepseekMoE, self).__init__()
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_ep(self, ep_group: ProcessGroup):
ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
def setup_process_groups(
self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
assert moe_tp_group is not None
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
set_moe_tensor_ep_group(p, ep_group)
# setup moe_dp group
self.moe_dp_group = moe_dp_group
self.moe_dp_size = moe_dp_group.size()
# setup global tp group
self.tp_group = tp_group
# setup moe tp group
self.moe_tp_group = moe_tp_group
if self.moe_tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
@staticmethod
def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
def from_native_module(
module,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPDeepseekMoE":
LazyInitContext.materialize(module)
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
module.setup_ep(kwargs["ep_group"])
module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -91,15 +127,24 @@ class EPDeepseekMoE(nn.Module):
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():
activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
expert = self.experts[self.expert_start_idx]
output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
output_states = expert(output_states)
output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
@@ -107,10 +152,16 @@ class EPDeepseekMoE(nn.Module):
if split_states.size(0) == 0: # no tokens routed to this expert
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = DPGradScalerIn.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
)
split_states = expert(split_states)
split_states = DPGradScalerOut.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(

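Note on the scaler swap above: MoeInGradScaler/MoeOutGradScaler are replaced by the EP and DP grad scalers imported at the top of the file. A minimal sketch of the shared idea, assuming only that these are identity-forward autograd functions that rescale gradients on the way back (this is not the colossalai.moe._operation implementation):

import torch

class _GradScale(torch.autograd.Function):
    # identity in forward; multiplies the incoming gradient by a fixed factor in backward
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # second return value is the gradient w.r.t. `scale`, which is a plain number
        return grad_output * ctx.scale, None

In that spirit, EPGradScalerIn/Out presumably compensate for the ep_size-way all-to-all fan-out, while DPGradScalerIn/Out use moe_dp_size together with the all-reduced activate_experts count (how many moe-dp replicas routed tokens to each expert); the exact factors live in colossalai.moe._operation.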
View File

@@ -116,8 +116,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
# TODO: drop tokens to reduce redundant tp group communication
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
@@ -125,24 +123,24 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
split_states = DPGradScalerIn.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
)
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = DPGradScalerIn.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
)
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
split_states = DPGradScalerOut.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
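For reference, a small worked example of the split-list arithmetic feeding all_to_all_uneven in both MoE blocks (the sizes are made up; only the view/sum pattern matches the code):

import torch

ep_size, num_experts_per_ep = 2, 2
# tokens routed from this rank to each of the ep_size * num_experts_per_ep global experts
input_split_sizes = torch.tensor([3, 1, 0, 5])
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
print(input_split_list)  # [4, 5]: send 4 tokens to EP rank 0 and 5 tokens to EP rank 1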

View File

@@ -161,7 +161,7 @@ _POLICY_LIST = {
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Deepseek
"transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
"transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation(
file_name="deepseek", class_name="DeepseekModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(

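The casing fix above matters because the auto-policy lookup keys on the string form of the model class. A simplified sketch of how such a key is assumed to be built (the real resolution logic lives in the policies module and also handles the transformers_modules path prefix):

def policy_key(model) -> str:
    # e.g. "...modeling_deepseek.DeepseekModel" for the remote-code DeepSeek model
    cls = model.__class__
    return f"{cls.__module__}.{cls.__qualname__}"

With the old entry, "DeepSeekModel" could never match the class that modeling_deepseek.py actually defines, DeepseekModel, so the DeepseekModelPolicy entry could not be found for it.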
View File

@@ -7,6 +7,7 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -39,16 +40,55 @@ class DeepseekPolicy(Policy):
)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
# tensor parallelism for non-moe params
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if getattr(self.shard_config, "ep_group", None) is not None:
policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
],
)
if self.shard_config.ep_group:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
kwargs={"ep_group": self.shard_config.ep_group},
kwargs={
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"moe_tp_group": self.shard_config.moe_tp_group,
},
)
],
policy=policy,

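The extra kwargs mirror the new EPDeepseekMoE.from_native_module signature shown in the modeling diff above. A hedged sketch of what this replacement description boils down to once shardformer applies it to a decoder layer (the wrapper function is illustrative, not shardformer's real entry point):

from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE

def replace_moe_mlp(decoder_layer, shard_config):
    # suffix="mlp": swap the MoE MLP for the expert-parallel version, handing it
    # every process group it needs (tp, moe-dp, ep, moe-tp)
    decoder_layer.mlp = EPDeepseekMoE.from_native_module(
        decoder_layer.mlp,
        tp_group=shard_config.tensor_parallel_process_group,
        moe_dp_group=shard_config.moe_dp_group,
        ep_group=shard_config.ep_group,
        moe_tp_group=shard_config.moe_tp_group,
    )
    return decoder_layer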
View File

@@ -8,6 +8,7 @@ from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -42,6 +43,13 @@ class MixtralPolicy(Policy):
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@@ -76,13 +84,22 @@ class MixtralPolicy(Policy):
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
SubModuleReplacementDescription( # or replicate?
suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
),
],
)
# TODO shard vocab embedding
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.ep_group:
# expert parallel

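The make_vocab_size_divisible_by kwarg passed to the embedding replacement pads the vocabulary so it splits evenly across tensor-parallel ranks. A small arithmetic sketch (the numbers are illustrative):

def padded_vocab_size(vocab_size: int, divisor: int) -> int:
    # round the vocab size up to the next multiple of `divisor`
    return ((vocab_size + divisor - 1) // divisor) * divisor

print(padded_vocab_size(32000, 64))  # 32000, already divisible
print(padded_vocab_size(32003, 64))  # 32064; the padded rows never map to real token ids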
View File

@@ -0,0 +1,133 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
@parameterize("config", [(1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float16
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
pp_size=1,
tp_size=tp_size,
moe_tp_size=tp_size,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
precision="fp32",
)
booster = Booster(plugin=plugin)
seed_all(10086)
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
config.num_hidden_layers = 2
config.num_attention_heads = NUM_HEADS
config.num_key_value_heads = NUM_HEADS
config.n_routed_experts = NUM_EXPERTS
config.num_experts_per_tok = TOP_K
torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
# create different input
seed_all(1453 + rank)
torch_model.train()
zero_model.train()
for _ in range(2):
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
zero_optimizer.backward(zero_output)
zero_optimizer.step()
zero_optimizer.zero_grad()
dist.all_reduce(zero_output)
all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
dist.all_gather(all_inputs, input_data)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dist.get_world_size()
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(zero_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_deepseek"
if dist.get_rank() == 0:
os.makedirs(model_dir, exist_ok=True)
dist.barrier()
booster.save_model(zero_model, model_dir, shard=True)
dist.barrier()
saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
check_model_equal(torch_model, saved_model)
dist.barrier()
if dist.get_rank() == 0:
shutil.rmtree(model_dir)
print(f"{dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_deepseek(world_size=4)

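A quick numeric check of the reference path in the test above: summing the per-rank losses and then dividing every gradient by the world size (the p.grad /= dist.get_world_size() step) matches what an averaged-loss backward, and hence the plugin's data-parallel gradient reduction, is expected to produce:

import torch

w = torch.tensor(2.0, requires_grad=True)
losses = [w * x for x in (1.0, 2.0, 3.0, 4.0)]  # one loss per emulated rank
sum(losses).backward()
w.grad /= 4  # mirrors p.grad /= dist.get_world_size()
print(w.grad)  # tensor(2.5000), the same gradient an averaged loss would give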
View File

@@ -24,16 +24,6 @@ NUM_HEADS = 4
TOP_K = 1
def split_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config

View File

@@ -16,6 +16,7 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
from tests.test_moe.moe_utils import loose_close
tokens, n_experts = 7, 4
hidden_size = 8
@@ -25,7 +26,7 @@ top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
if not torch.equal(p1.half(), p2.half()):
if loose_close(p1, p2, p1.dtype):
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")

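check_model_equal now defers to loose_close from tests/test_moe/moe_utils instead of exact torch.equal. A plausible stand-in for such a dtype-aware comparison, with assumed name and tolerances (the real helper may differ):

import torch

def approx_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    # looser tolerances for half-precision dtypes; raises if the tensors drift apart
    rtol, atol = (1e-2, 1e-2) if dtype in (torch.float16, torch.bfloat16) else (1e-5, 1e-8)
    torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol)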
View File

@@ -21,16 +21,6 @@ NUM_HEADS = 4
TOP_K = 2
def split_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):

View File

@@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
NUM_BATCH=4
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS=2
NUM_HEADS = 2
TOP_K = 1
def split_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
pp_size=1,
tp_size=1,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1
pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
)
booster = Booster(plugin=plugin)
@@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int):
zero_model.train()
for _ in range(2):
# zero-dp forward
input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)