mirror of https://github.com/hpcaitech/ColossalAI
[moe] test deepseek
commit 8d3d7f3cbd (parent 335ad3c6fb)
@@ -1,21 +1,27 @@
from typing import List, Optional, Union
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


# copied from modeling_deepseek.py
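The import change above swaps the old MoeInGradScaler/MoeOutGradScaler pair for EPGradScalerIn/EPGradScalerOut and adds DPGradScalerIn/DPGradScalerOut, so expert-parallel and MoE data-parallel gradient scaling are handled by separate autograd functions. The common pattern behind such scalers is an identity forward with a rescaled backward; the sketch below only illustrates that pattern and is not ColossalAI's implementation.

import torch


class EPGradScaleSketch(torch.autograd.Function):
    """Identity in forward, multiplies the gradient by ep_size in backward (illustration only)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, ep_size: int) -> torch.Tensor:
        ctx.ep_size = ep_size
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # None corresponds to the non-tensor argument ep_size
        return grad_output * ctx.ep_size, None


if __name__ == "__main__":
    x = torch.ones(3, requires_grad=True)
    EPGradScaleSketch.apply(x, 4).sum().backward()
    print(x.grad)  # tensor([4., 4., 4.])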
@@ -42,30 +48,60 @@ class AddAuxiliaryLoss(torch.autograd.Function):

class EPDeepseekMoE(nn.Module):
    def __init__(self):
        super(EPDeepseekMoE, self).__init__()
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_ep(self, ep_group: ProcessGroup):
        ep_group = ep_group
        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
    def setup_process_groups(
        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
    ):
        assert tp_group is not None
        assert moe_dp_group is not None
        assert ep_group is not None
        assert moe_tp_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))
        for p in self.experts.parameters():
            p.ep_group = ep_group
            set_moe_tensor_ep_group(p, ep_group)

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = moe_dp_group.size()

        # setup global tp group
        self.tp_group = tp_group

        # setup moe tp group
        self.moe_tp_group = moe_tp_group
        if self.moe_tp_group.size() > 1:
            for expert in held_experts:
                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)

    @staticmethod
    def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
    def from_native_module(
        module,
        tp_group: ProcessGroup,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        moe_tp_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EPDeepseekMoE":
        LazyInitContext.materialize(module)
        if module.__class__.__name__ == "DeepseekMLP":
            return module
        module.__class__ = EPDeepseekMoE
        assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
        module.setup_ep(kwargs["ep_group"])
        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
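setup_process_groups above splits the routed experts evenly across expert-parallel ranks, keeps only the local slice (freeing the rest via set_tensors_to_none), tags the surviving parameters with their EP group, and optionally tensor-parallelizes the held experts' projections. The partitioning itself is just a contiguous slice per rank; a hypothetical standalone helper, for illustration only:

def local_expert_slice(num_experts: int, ep_size: int, ep_rank: int) -> range:
    """Return the global expert indices held by one expert-parallel rank (sketch)."""
    assert num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    num_experts_per_ep = num_experts // ep_size
    expert_start_idx = ep_rank * num_experts_per_ep
    return range(expert_start_idx, expert_start_idx + num_experts_per_ep)


# e.g. 8 routed experts over ep_size=4: rank 2 holds experts 4 and 5
assert list(local_expert_slice(8, 4, 2)) == [4, 5]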
@@ -91,15 +127,24 @@ class EPDeepseekMoE(nn.Module):
        # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

        with torch.no_grad():
            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
            for i in range(1, self.ep_size):
                activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
            activate_experts = (activate_experts > 0).float()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)

        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
        output_states = EPGradScalerIn.apply(output_states, self.ep_size)

        if output_states.size(0) > 0:
            if self.num_experts_per_ep == 1:
                expert = self.experts[self.expert_start_idx]
                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
                output_states = expert(output_states)
                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
            else:
                output_states_splits = output_states.split(output_split_sizes.tolist())
                output_states_list = []
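In the forward path above, each rank first exchanges per-expert token counts (dist.all_to_all_single on the split-size tensors), then collapses those counts into one send/receive size per EP rank for all_to_all_uneven; activate_experts marks which local experts received any tokens, and is further all-reduced over the MoE data-parallel group so the DP grad scalers can account for inactive experts. A shape-only sketch with made-up counts and no process groups involved:

import torch

ep_size, num_experts_per_ep = 2, 2

# per-expert token counts this rank will receive, laid out as (sender EP rank, local expert)
output_split_sizes = torch.tensor([3, 0, 2, 1])  # length ep_size * num_experts_per_ep

# total tokens per local expert across all sender ranks; >0 marks the expert as active
activate_experts = output_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=0)
activate_experts = (activate_experts > 0).float()  # -> tensor([1., 1.])

# collapse per-expert counts into one receive size per EP rank for all_to_all_uneven
output_split_list = output_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
print(activate_experts, output_split_list)  # tensor([1., 1.]) [3, 3]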
@@ -107,10 +152,16 @@ class EPDeepseekMoE(nn.Module):
                    if split_states.size(0) == 0:  # no token routed to this expert
                        continue
                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                    split_states = DPGradScalerIn.apply(
                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                    )
                    split_states = expert(split_states)
                    split_states = DPGradScalerOut.apply(
                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                    )
                    output_states_list.append(split_states)
                output_states = torch.cat(output_states_list)
        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
        recover_token_idx = torch.empty_like(flat_topk_token_idx)
        recover_token_idx[flat_topk_token_idx] = torch.arange(
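The hunk ends mid-statement: the truncated torch.arange call builds an inverse permutation, so tokens that were sorted by expert for dispatch can be scattered back to their original positions after the return all_to_all_uneven. A tiny self-contained illustration of that index trick:

import torch

flat_topk_token_idx = torch.tensor([2, 0, 3, 1])    # order tokens were dispatched in
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
    flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
)
sorted_vals = torch.tensor([20.0, 0.0, 30.0, 10.0])  # values in dispatch order
print(sorted_vals[recover_token_idx])                # tensor([ 0., 10., 20., 30.]) -- original order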
@@ -116,8 +116,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()

        # TODO drop tokens to reduce tp group redundant communication

        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        # compute expert output
        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
@@ -125,24 +123,24 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
        if self.num_experts_per_ep == 1:
            # no need to split
            expert = self.experts[self.expert_start_idx]
            output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
            output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
            output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
            output_states = expert.w2(output_states)
            output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
            output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
        else:
            output_states_splits = output_states.split(output_split_sizes.tolist())
            output_states_list = []
            for i, split_states in enumerate(output_states_splits):
                if split_states.size(0) == 0:
                    continue
                split_states = DPGradScalerIn.apply(
                    split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
                )
                expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                split_states = DPGradScalerIn.apply(
                    split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                )
                split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                split_states = expert.w2(split_states)
                split_states = DPGradScalerOut.apply(
                    split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
                    split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                )
                output_states_list.append(split_states)
            output_states = torch.cat(output_states_list)
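The Mixtral change mirrors the Deepseek one: activate_experts entries are now passed to the DP grad scalers as tensor elements rather than Python scalars (the .item() calls are dropped), while each held expert still applies its gated MLP, w2(act_fn(w1(x)) * w3(x)). A minimal sketch of that expert block, with hypothetical sizes and SiLU assumed as the activation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLPExpertSketch(nn.Module):
    """Gated MLP of the form w2(act(w1(x)) * w3(x)); sizes and activation are assumptions."""

    def __init__(self, hidden_size: int = 8, intermediate_size: int = 16):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection
        self.act_fn = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))


print(GatedMLPExpertSketch()(torch.randn(5, 8)).shape)  # torch.Size([5, 8])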
@@ -161,7 +161,7 @@ _POLICY_LIST = {
        file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
    ),
    # Deepseek
    "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
    "transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation(
        file_name="deepseek", class_name="DeepseekModelPolicy"
    ),
    "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
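The auto-policy fix corrects the registry key from DeepSeekModel to DeepseekModel; because the lookup is keyed on the model class's fully qualified name, the mismatched casing would mean no policy is ever found for the remote-code Deepseek model. A simplified sketch of that kind of lookup (illustration only, not ColossalAI's actual resolver):

_policy_registry = {
    "transformers_modules.modeling_deepseek.DeepseekModel": ("deepseek", "DeepseekModelPolicy"),
}


def find_policy_location(model_cls):
    # the key must match the qualified class name exactly, including casing
    key = f"{model_cls.__module__}.{model_cls.__qualname__}"
    return _policy_registry.get(key)  # None -> no auto policy registered for this model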
@@ -7,6 +7,7 @@ from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -39,16 +40,55 @@ class DeepseekPolicy(Policy):
            )

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
            # tensor parallelism for non-moe params
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of attention heads must be divisible by tensor parallel size."
            assert (
                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of key_value heads must be divisible by tensor parallel size."
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }

        if getattr(self.shard_config, "ep_group", None) is not None:
            policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

        if self.shard_config.ep_group:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EPDeepseekMoE,
                        kwargs={"ep_group": self.shard_config.ep_group},
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                            "moe_tp_group": self.shard_config.moe_tp_group,
                        },
                    )
                ],
                policy=policy,
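The policy now forwards four process groups (ep_group, tp_group, moe_dp_group, moe_tp_group) through the SubModuleReplacementDescription kwargs, matching the new EPDeepseekMoE.from_native_module signature above. Roughly, the shardformer resolves each matching suffix and calls the target module's from_native_module with those kwargs; a simplified sketch of that mechanism, not the real replacement machinery:

import torch.nn as nn


def replace_submodule(root: nn.Module, suffix: str, target_module, kwargs: dict) -> None:
    """Swap the submodule at `suffix` for target_module.from_native_module(native, **kwargs)."""
    *parents, leaf = suffix.split(".")
    holder = root
    for name in parents:  # walk e.g. "self_attn" for a suffix like "self_attn.q_proj"
        holder = getattr(holder, name)
    native = getattr(holder, leaf)
    setattr(holder, leaf, target_module.from_native_module(native, **kwargs))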
@@ -8,6 +8,7 @@ from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -42,6 +43,13 @@ class MixtralPolicy(Policy):
                "Mixtral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = VocabParallelEmbedding1D
        else:
            if self.tie_weight:
                embedding_cls = PaddingEmbedding

        if self.shard_config.enable_tensor_parallelism:
            # tensor parallelism for non-moe params
            assert (
@@ -76,13 +84,22 @@ class MixtralPolicy(Policy):
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                    SubModuleReplacementDescription(  # or replicate?
                        suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
                    ),
                ],
            )

        # TODO shard vocab embedding
        if embedding_cls is not None:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=embedding_cls,
                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                ),
                policy=policy,
                target_key=MixtralModel,
            )

        if self.shard_config.ep_group:
            # expert parallel
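The block_sparse_moe.gate is replaced with a column-parallel linear but with gather_output=True, because top-k routing needs the scores of every expert on every rank; a column-split output alone would leave each rank with only a slice of the router logits. A toy, non-distributed illustration of why that matters:

import torch

router_logits = torch.tensor([[0.1, 2.0, 0.3, 1.5]])  # scores for 4 experts
full_topk = router_logits.topk(2, dim=-1).indices     # tensor([[1, 3]]) -- correct routing

local_slice = router_logits[:, :2]                    # what one TP rank would hold un-gathered
local_topk = local_slice.topk(2, dim=-1).indices      # tensor([[1, 0]]) -- wrong routing
print(full_topk, local_topk)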
@@ -0,0 +1,133 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal

NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1


@parameterize("config", [(1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
    stage, ep_size, tp_size = config
    dtype = torch.float16

    rank = torch.distributed.get_rank()
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=1,
        tp_size=tp_size,
        moe_tp_size=tp_size,
        ep_size=ep_size,
        zero_stage=stage,
        overlap_communication=False,
        initial_scale=1,
        precision="fp32",
    )
    booster = Booster(plugin=plugin)

    seed_all(10086)

    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
    config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
    config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
    config.num_hidden_layers = 2
    config.num_attention_heads = NUM_HEADS
    config.num_key_value_heads = NUM_HEADS
    config.n_routed_experts = NUM_EXPERTS
    config.num_experts_per_tok = TOP_K
    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    zero_model = deepcopy(torch_model).to(dtype)
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

    # create different input
    seed_all(1453 + rank)

    torch_model.train()
    zero_model.train()
    for _ in range(2):
        input_data = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(input_data, group=plugin.tp_group)  # tp requires duplicate input

        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
        zero_optimizer.backward(zero_output)
        zero_optimizer.step()
        zero_optimizer.zero_grad()
        dist.all_reduce(zero_output)

        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
        dist.all_gather(all_inputs, input_data)

        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dist.get_world_size()
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(zero_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_deepseek"
    if dist.get_rank() == 0:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()

    booster.save_model(zero_model, model_dir, shard=True)

    dist.barrier()

    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)

    dist.barrier()
    if dist.get_rank() == 0:
        shutil.rmtree(model_dir)

    print(f"{dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_deepseek(world_size=4)
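The new test drives a tiny Deepseek-MoE config through MoeHybridParallelPlugin and checks two things: the all-reduced parallel loss matches the unsharded model's losses summed over every rank's input, and a model reloaded from the sharded checkpoint matches the reference parameters. The gradient-averaging step (dividing the reference grads by world size) mimics data-parallel averaging; a single-process sketch of that equivalence with a toy linear model and two fake "ranks":

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1, bias=False)
ref = torch.nn.Linear(4, 1, bias=False)
ref.load_state_dict(model.state_dict())

batches = [torch.randn(3, 4), torch.randn(3, 4)]  # one batch per fake rank

# "data parallel": every rank computes a gradient on its own batch, then the grads are averaged
dp_grad = sum(torch.autograd.grad(model(b).mean(), model.weight)[0] for b in batches) / len(batches)

# reference: replay all batches serially, then divide by world size (as the test does)
for b in batches:
    ref(b).mean().backward()
ref_grad = ref.weight.grad / len(batches)

assert torch.allclose(dp_grad, ref_grad)
print("dp grad matches replayed-and-averaged reference grad")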
@@ -24,16 +24,6 @@ NUM_HEADS = 4
TOP_K = 1


def split_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
    stage, ep_size, tp_size = config
@@ -16,6 +16,7 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
from tests.test_moe.moe_utils import loose_close

tokens, n_experts = 7, 4
hidden_size = 8

@@ -25,7 +26,7 @@ top_k = 2
def check_model_equal(model1, model2):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
        if not torch.equal(p1.half(), p2.half()):
        if loose_close(p1, p2, p1.dtype):
            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
            raise AssertionError(f"Model parameter {name} is not equal")
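check_model_equal now relies on a dtype-aware tolerance helper instead of bit-exact torch.equal on half-precision copies, which is more robust to the small numerical drift introduced by sharded checkpointing. A sketch of what such a helper can look like, with hypothetical tolerances (this is not the actual loose_close from tests.test_moe.moe_utils):

import torch


def roughly_equal(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> bool:
    """Compare two tensors with a tolerance chosen by dtype (illustrative values)."""
    tolerances = {
        torch.float16: (1e-3, 1e-3),
        torch.bfloat16: (1e-2, 1e-2),
        torch.float32: (1e-5, 1e-5),
    }
    rtol, atol = tolerances.get(dtype, (1e-5, 1e-5))
    return torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol)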
@@ -21,16 +21,6 @@ NUM_HEADS = 4
TOP_K = 2


def split_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
@@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close

NUM_BATCH=4
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS=2
NUM_HEADS = 2
TOP_K = 1

def split_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])

@@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=1,
        tp_size=1,
        ep_size=ep_size,
        zero_stage=stage,
        overlap_communication=False,
        initial_scale=1
        pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
    )
    booster = Booster(plugin=plugin)

@@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int):
    zero_model.train()
    for _ in range(2):
        # zero-dp forward
        input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
        input_data = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
        # zero-dp backward
        zero_optimizer.backward(zero_output)