From 74eccac0db4f281b14a7042378e4158a4147fc0d Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 16 Jul 2024 10:10:40 +0000 Subject: [PATCH] [moe] test deepseek --- colossalai/shardformer/modeling/deepseek.py | 81 +++++++++-- colossalai/shardformer/modeling/mixtral.py | 14 +- .../shardformer/policies/auto_policy.py | 2 +- colossalai/shardformer/policies/deepseek.py | 46 +++++- colossalai/shardformer/policies/mixtral.py | 21 ++- tests/test_moe/modelling/test_deepseek.py | 133 ++++++++++++++++++ tests/test_moe/modelling/test_mixtral.py | 10 -- tests/test_moe/test_moe_checkpoint.py | 3 +- tests/test_moe/test_moe_ep_tp.py | 10 -- tests/test_moe/test_moe_ep_zero.py | 24 +--- 10 files changed, 276 insertions(+), 68 deletions(-) create mode 100644 tests/test_moe/modelling/test_deepseek.py diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index 6e79ce144..33fac9b93 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -1,21 +1,27 @@ -from typing import List, Optional, Union +from typing import List, Optional import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup - -# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo from torch.nn import CrossEntropyLoss from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import is_flash_attn_2_available, logging from colossalai.lazy import LazyInitContext -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.moe._operation import ( + DPGradScalerIn, + DPGradScalerOut, + EPGradScalerIn, + EPGradScalerOut, + all_to_all_uneven, +) from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group # copied from modeling_deepseek.py @@ -42,30 +48,60 @@ class AddAuxiliaryLoss(torch.autograd.Function): class EPDeepseekMoE(nn.Module): def __init__(self): - super(EPDeepseekMoE, self).__init__() + raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_ep(self, ep_group: ProcessGroup): - ep_group = ep_group - self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 - self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + def setup_process_groups( + self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup + ): + assert tp_group is not None + assert moe_dp_group is not None + assert ep_group is not None + assert moe_tp_group is not None + + self.ep_size = dist.get_world_size(ep_group) + self.ep_rank = dist.get_rank(ep_group) self.num_experts = self.config.n_routed_experts assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group self.num_experts_per_ep = self.num_experts // self.ep_size self.expert_start_idx = self.ep_rank * self.num_experts_per_ep held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) for p in self.experts.parameters(): - p.ep_group = ep_group + set_moe_tensor_ep_group(p, ep_group) 
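+        # at this point the current EP rank keeps only its contiguous slice of experts,
+        # [expert_start_idx, expert_start_idx + num_experts_per_ep); all other experts'
+        # parameters were released by set_tensors_to_none, and the kept expert parameters
+        # are tagged with the EP process group via set_moe_tensor_ep_group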
+ + # setup moe_dp group + self.moe_dp_group = moe_dp_group + self.moe_dp_size = moe_dp_group.size() + + # setup global tp group + self.tp_group = tp_group + + # setup moe tp group + self.moe_tp_group = moe_tp_group + if self.moe_tp_group.size() > 1: + for expert in held_experts: + expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group) + expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group) + expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group) @staticmethod - def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE": + def from_native_module( + module, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + moe_tp_group: ProcessGroup, + *args, + **kwargs, + ) -> "EPDeepseekMoE": LazyInitContext.materialize(module) if module.__class__.__name__ == "DeepseekMLP": return module module.__class__ = EPDeepseekMoE - assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!" - module.setup_ep(kwargs["ep_group"]) + module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -91,15 +127,24 @@ class EPDeepseekMoE(nn.Module): # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3] dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + with torch.no_grad(): + activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() + for i in range(1, self.ep_size): + activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] + activate_experts = (activate_experts > 0).float() + dist.all_reduce(activate_experts, group=self.moe_dp_group) + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) - output_states = MoeInGradScaler.apply(output_states, self.ep_size) + output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: if self.num_experts_per_ep == 1: expert = self.experts[self.expert_start_idx] + output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0]) output_states = expert(output_states) + output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0]) else: output_states_splits = output_states.split(output_split_sizes.tolist()) output_states_list = [] @@ -107,10 +152,16 @@ class EPDeepseekMoE(nn.Module): if split_states.size(0) == 0: # no token routed to this experts continue expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = DPGradScalerIn.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) split_states = expert(split_states) + split_states = DPGradScalerOut.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) output_states_list.append(split_states) output_states = torch.cat(output_states_list) - output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + output_states = EPGradScalerOut.apply(output_states, self.ep_size) dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, 
self.ep_group) recover_token_idx = torch.empty_like(flat_topk_token_idx) recover_token_idx[flat_topk_token_idx] = torch.arange( diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 86ef6c959..cfa7da6c0 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -116,8 +116,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - # TODO drop tokens to reduce tp group redundant communication - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) # compute expert output output_states = EPGradScalerIn.apply(output_states, self.ep_size) @@ -125,24 +123,24 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): if self.num_experts_per_ep == 1: # no need to split expert = self.experts[self.expert_start_idx] - output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item()) + output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0]) output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) output_states = expert.w2(output_states) - output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item()) + output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0]) else: output_states_splits = output_states.split(output_split_sizes.tolist()) output_states_list = [] for i, split_states in enumerate(output_states_splits): if split_states.size(0) == 0: continue - split_states = DPGradScalerIn.apply( - split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item() - ) expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = DPGradScalerIn.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) split_states = expert.w2(split_states) split_states = DPGradScalerOut.apply( - split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item() + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] ) output_states_list.append(split_states) output_states = torch.cat(output_states_list) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 1e0af031a..f2533da4b 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -161,7 +161,7 @@ _POLICY_LIST = { file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), # Deepseek - "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation( + "transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation( file_name="deepseek", class_name="DeepseekModelPolicy" ), "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation( diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 8ebda357b..5a67d653d 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -7,6 +7,7 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import FusedRMSNorm, 
Linear1D_Col +from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -39,16 +40,55 @@ class DeepseekPolicy(Policy): ) if self.shard_config.enable_tensor_parallelism: - raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.") + # tensor parallelism for non-moe params + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } - if getattr(self.shard_config, "ep_group", None) is not None: + policy["DeepseekDecoderLayer"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + ], + ) + + if self.shard_config.ep_group: # expert parallel self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( suffix="mlp", target_module=EPDeepseekMoE, - kwargs={"ep_group": self.shard_config.ep_group}, + kwargs={ + "ep_group": self.shard_config.ep_group, + "tp_group": self.shard_config.tensor_parallel_process_group, + "moe_dp_group": self.shard_config.moe_dp_group, + "moe_tp_group": self.shard_config.moe_tp_group, + }, ) ], policy=policy, diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 4b77a167f..8905b5696 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -8,6 +8,7 @@ from torch.nn import Module from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -42,6 +43,13 @@ class MixtralPolicy(Policy): "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." 
) + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_tensor_parallelism: # tensor parallelism for non-moe params assert ( @@ -76,13 +84,22 @@ class MixtralPolicy(Policy): suffix="self_attn.o_proj", target_module=Linear1D_Row, ), - SubModuleReplacementDescription( + SubModuleReplacementDescription( # or replicate? suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True} ), ], ) - # TODO shard vocab embedding + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=MixtralModel, + ) if self.shard_config.ep_group: # expert parallel diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py new file mode 100644 index 000000000..42daea512 --- /dev/null +++ b/tests/test_moe/modelling/test_deepseek.py @@ -0,0 +1,133 @@ +import os +import shutil +from copy import deepcopy +from typing import Tuple + +import pytest +import torch +import torch.distributed as dist +from transformers import AutoConfig, AutoModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import loose_close +from tests.test_moe.test_moe_checkpoint import check_model_equal + +NUM_BATCH = 4 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 + + +@parameterize("config", [(1, 1, 1)]) +def run_zero_with_original_model(config: Tuple[int, ...]): + stage, ep_size, tp_size = config + dtype = torch.float16 + + rank = torch.distributed.get_rank() + torch.cuda.set_device(dist.get_rank()) + + plugin = MoeHybridParallelPlugin( + pp_size=1, + tp_size=tp_size, + moe_tp_size=tp_size, + ep_size=ep_size, + zero_stage=stage, + overlap_communication=False, + initial_scale=1, + precision="fp32", + ) + booster = Booster(plugin=plugin) + + seed_all(10086) + + config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True) + config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS + config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2 + config.num_hidden_layers = 2 + config.num_attention_heads = NUM_HEADS + config.num_key_value_heads = NUM_HEADS + config.n_routed_experts = NUM_EXPERTS + config.num_experts_per_tok = TOP_K + torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype) + + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + + zero_model = deepcopy(torch_model).to(dtype) + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + # create different input + seed_all(1453 + rank) + + torch_model.train() + zero_model.train() + for _ in range(2): + input_data = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input + + zero_output = 
zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+        zero_optimizer.backward(zero_output)
+        zero_optimizer.step()
+        zero_optimizer.zero_grad()
+        dist.all_reduce(zero_output)
+
+        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
+        dist.all_gather(all_inputs, input_data)
+
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+        # avg dp grads
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dist.get_world_size()
+        torch_optimizer.step()
+        torch_optimizer.zero_grad()
+
+        loose_close(zero_output, torch_output_sum, dtype=dtype)
+
+    # use checkpoint to load sharded zero model
+    model_dir = "./test_deepseek"
+    if dist.get_rank() == 0:
+        os.makedirs(model_dir, exist_ok=True)
+
+    dist.barrier()
+
+    booster.save_model(zero_model, model_dir, shard=True)
+
+    dist.barrier()
+
+    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
+    check_model_equal(torch_model, saved_model)
+
+    dist.barrier()
+    if dist.get_rank() == 0:
+        shutil.rmtree(model_dir)
+
+    print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_deepseek(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_deepseek(world_size=4)
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index 8309bfb22..6e6f0b2b5 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -24,16 +24,6 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-def split_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
 @parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 6f3c5b299..4bcf701de 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -16,6 +16,7 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import parameterize, spawn
 from colossalai.testing.utils import spawn
+from tests.test_moe.moe_utils import loose_close
 
 tokens, n_experts = 7, 4
 hidden_size = 8
@@ -25,7 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if not torch.equal(p1.half(), p2.half()):
+        if loose_close(p1, p2, p1.dtype):
             print(f"Model parameter {name} is not equal.
is_moe_tensor: {is_moe_tensor(p1)}") raise AssertionError(f"Model parameter {name} is not equal") diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index e944a8c0a..29881c9ab 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -21,16 +21,6 @@ NUM_HEADS = 4 TOP_K = 2 -def split_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - @parameterize("stage", [1]) @parameterize("ep_size", [1, 2, 4]) def run_zero_with_original_model(stage: int, ep_size: int): diff --git a/tests/test_moe/test_moe_ep_zero.py b/tests/test_moe/test_moe_ep_zero.py index c5adaad06..40e3bacb3 100644 --- a/tests/test_moe/test_moe_ep_zero.py +++ b/tests/test_moe/test_moe_ep_zero.py @@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from tests.test_moe.moe_utils import loose_close -NUM_BATCH=4 +NUM_BATCH = 4 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS=2 +NUM_HEADS = 2 TOP_K = 1 -def split_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - @parameterize("stage", [1]) @parameterize("ep_size", [1, 2, 4]) @@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int): torch.cuda.set_device(dist.get_rank()) plugin = MoeHybridParallelPlugin( - pp_size=1, - tp_size=1, - ep_size=ep_size, - zero_stage=stage, - overlap_communication=False, - initial_scale=1 + pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1 ) booster = Booster(plugin=plugin) @@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int): zero_model.train() for _ in range(2): # zero-dp forward - input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda() + input_data = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() # zero-dp backward zero_optimizer.backward(zero_output)
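+        # NOTE: backward goes through the booster-wrapped optimizer (rather than calling
+        # zero_output.backward() directly) so the ZeRO wrapper can handle loss scaling and
+        # gradient reduction across the data-parallel group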