mirror of https://github.com/hpcaitech/ColossalAI
[moe] implement transit between non moe tp and ep
parent
37443cc7e4
commit
b5bfeb2efd
|
@ -1068,7 +1068,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
||||||
|
|
||||||
self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}")
|
self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
|
||||||
|
|
||||||
self.stage_manager = None
|
self.stage_manager = None
|
||||||
self.schedule = None
|
self.schedule = None
|
||||||
|
|
|
@ -30,8 +30,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
model: Module,
|
model: Module,
|
||||||
use_pipeline: bool,
|
use_pipeline: bool,
|
||||||
dp_process_group: ProcessGroup, # the dp pg for comm
|
dp_process_group: ProcessGroup, # dp pg for comm
|
||||||
moe_dp_group: ProcessGroup, # the moe dp pg for gomm
|
moe_dp_group: ProcessGroup, # moe dp pg for comm
|
||||||
param_info: OrderedDict,
|
param_info: OrderedDict,
|
||||||
initial_scale: int = 2**16, # grad scaler config
|
initial_scale: int = 2**16, # grad scaler config
|
||||||
min_scale: int = 1,
|
min_scale: int = 1,
|
||||||
|
@ -44,7 +44,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
reduce_bucket_size: int = 1024 * 1024, # communication
|
reduce_bucket_size: int = 1024 * 1024, # communication
|
||||||
communication_dtype: Optional[torch.dtype] = None,
|
communication_dtype: Optional[torch.dtype] = None,
|
||||||
overlap_communication: bool = True,
|
overlap_communication: bool = False,
|
||||||
partition_grad: bool = False, # stage 2 flag
|
partition_grad: bool = False, # stage 2 flag
|
||||||
cpu_offload: bool = False, # cpu offload
|
cpu_offload: bool = False, # cpu offload
|
||||||
forced_dtype: Optional[torch.dtype] = None,
|
forced_dtype: Optional[torch.dtype] = None,
|
||||||
|
@ -88,7 +88,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||||
TODO: add docstring
|
TODO: add docstring
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None:
|
def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||||
|
@ -98,14 +98,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||||
)
|
)
|
||||||
self.ddp_config["find_unused_parameters"] = True
|
self.ddp_config["find_unused_parameters"] = True
|
||||||
|
|
||||||
if ep_tp_size != 1:
|
if moe_tp_size != 1:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
self.moe_dp_size = world_size // (ep_size * ep_tp_size)
|
self.moe_dp_size = world_size // (ep_size * moe_tp_size)
|
||||||
self.ep_size = ep_size
|
self.ep_size = ep_size
|
||||||
self.moe_tp_size = ep_tp_size
|
self.moe_tp_size = moe_tp_size
|
||||||
|
|
||||||
self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
|
self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
|
||||||
self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
|
self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
|
||||||
|
@ -114,7 +114,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||||
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
|
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
|
||||||
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
|
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
|
||||||
|
|
||||||
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}")
|
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
|
||||||
|
|
||||||
# set ep_group after super init
|
# set ep_group after super init
|
||||||
# TODO do it in a better way
|
# TODO do it in a better way
|
||||||
|
|
|
@ -397,3 +397,106 @@ def all_to_all_uneven(
|
||||||
inputs.requires_grad
|
inputs.requires_grad
|
||||||
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
|
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
|
||||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================
|
||||||
|
# This code section was modified from
|
||||||
|
# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
|
||||||
|
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# DeepSpeed Team
|
||||||
|
|
||||||
|
# The file has been adapted from the following Megatron-LM file:
|
||||||
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
|
||||||
|
# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
|
||||||
|
# We retain the following copyright from the original files:
|
||||||
|
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
def _gather_tokens(input_, dim: int, tp_group: ProcessGroup):
|
||||||
|
"""Gather tensors and concatenate them along a dimension"""
|
||||||
|
|
||||||
|
input_ = input_.contiguous()
|
||||||
|
# Size and dimension.
|
||||||
|
rank = tp_group.rank()
|
||||||
|
|
||||||
|
tensor_list = [torch.empty_like(input_) for _ in range(tp_group.size())]
|
||||||
|
tensor_list[rank] = input_
|
||||||
|
dist.all_gather(tensor_list, input_, group=tp_group)
|
||||||
|
|
||||||
|
# Note: torch.cat already creates a contiguous tensor.
|
||||||
|
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
|
||||||
|
"""Divide a tensor among the tensor parallel ranks"""
|
||||||
|
|
||||||
|
total_chunks = tp_group.size()
|
||||||
|
this_chunk = tp_group.rank()
|
||||||
|
assert input_.shape[
|
||||||
|
dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
|
||||||
|
chunk_size = input_.shape[dim] // total_chunks
|
||||||
|
|
||||||
|
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
|
||||||
|
|
||||||
|
|
||||||
|
class _GatherTokens(torch.autograd.Function):
|
||||||
|
"""All gather tokens among the tensor parallel ranks"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
|
||||||
|
ctx.dim = dim
|
||||||
|
ctx.tp_group = tp_group
|
||||||
|
return _gather_tokens(input_, dim, tp_group)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
return _drop_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
|
||||||
|
|
||||||
|
|
||||||
|
class _DropTokens(torch.autograd.Function):
|
||||||
|
"Divide tokens equally among the tensor parallel ranks"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
|
||||||
|
ctx.dim = dim
|
||||||
|
ctx.tp_group = tp_group
|
||||||
|
return _drop_tokens(input_, dim, tp_group)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, input_: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||||
|
return _gather_tokens(input_, ctx.dim, ctx.tp_group), None, None
|
||||||
|
|
||||||
|
|
||||||
|
def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
|
||||||
|
if tp_group.size() == 1:
|
||||||
|
# no tensor parallelism for non-experts
|
||||||
|
return input_
|
||||||
|
assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
|
||||||
|
return _GatherTokens.apply(input_, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
|
||||||
|
if tp_group.size() == 1:
|
||||||
|
# no tensor parallelism for non-experts
|
||||||
|
return input_
|
||||||
|
assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
|
||||||
|
return _DropTokens.apply(input_, dim, tp_group)
|
||||||
|
|
||||||
|
# ===========================================================
|
||||||
|
|
|
@ -14,21 +14,21 @@ from transformers.models.mixtral.modeling_mixtral import (
|
||||||
from transformers.utils import is_flash_attn_2_available, logging
|
from transformers.utils import is_flash_attn_2_available, logging
|
||||||
|
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||||
|
|
||||||
|
|
||||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||||
def __init__(self, config, ep_group):
|
def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.setup_ep(ep_group)
|
self.setup_process_groups(ep_group, tp_group, moe_tp_group)
|
||||||
|
|
||||||
def setup_ep(self, ep_group: ProcessGroup):
|
def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
|
||||||
ep_group = ep_group
|
# setup ep group
|
||||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
self.ep_size = dist.get_world_size(ep_group)
|
||||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
self.ep_rank = dist.get_rank(ep_group)
|
||||||
self.ep_group = ep_group
|
self.ep_group = ep_group
|
||||||
|
|
||||||
if self.num_experts % self.ep_size != 0:
|
if self.num_experts % self.ep_size != 0:
|
||||||
|
@ -42,13 +42,19 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||||
for p in self.experts.parameters():
|
for p in self.experts.parameters():
|
||||||
p.ep_group = ep_group
|
p.ep_group = ep_group
|
||||||
|
|
||||||
|
# setup global tp group
|
||||||
|
self.tp_group = tp_group
|
||||||
|
|
||||||
|
# setup moe tp group
|
||||||
|
self.moe_tp_group = moe_tp_group
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(
|
def from_native_module(
|
||||||
module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
|
module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
|
||||||
) -> "EPMixtralSparseMoeBlock":
|
) -> "EPMixtralSparseMoeBlock":
|
||||||
LazyInitContext.materialize(module)
|
LazyInitContext.materialize(module)
|
||||||
module.__class__ = EPMixtralSparseMoeBlock
|
module.__class__ = EPMixtralSparseMoeBlock
|
||||||
module.setup_ep(ep_group)
|
module.setup_process_groups(ep_group)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
@ -72,6 +78,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||||
|
|
||||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||||
|
|
||||||
|
if self.tp_group is not None and self.tp_group.size() > 1:
|
||||||
|
dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
|
||||||
|
|
||||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||||
# compute expert output
|
# compute expert output
|
||||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||||
|
@ -94,6 +104,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||||
output_states = torch.cat(output_states_list)
|
output_states = torch.cat(output_states_list)
|
||||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||||
|
|
||||||
|
if self.tp_group is not None and self.tp_group.size() > 1:
|
||||||
|
dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
|
||||||
|
|
||||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||||
selected_experts_idx.size(0), device=selected_experts_idx.device
|
selected_experts_idx.size(0), device=selected_experts_idx.device
|
||||||
|
|
|
@ -8,6 +8,7 @@ from torch.nn import Module
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
|
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
|
||||||
|
|
||||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||||
|
from colossalai.shardformer.layer.linear import Linear1D_Row
|
||||||
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
|
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
|
||||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
@ -20,15 +21,15 @@ class MixtralPolicy(Policy):
|
||||||
|
|
||||||
def preprocess(self):
|
def preprocess(self):
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
raise NotImplementedError
|
# non-moe params tensor parallelism
|
||||||
|
|
||||||
# # Resize embedding
|
|
||||||
# vocab_size = self.model.config.vocab_size
|
|
||||||
# world_size = self.shard_config.tensor_parallel_size
|
|
||||||
|
|
||||||
# if vocab_size % world_size != 0:
|
# Resize embedding
|
||||||
# new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
vocab_size = self.model.config.vocab_size
|
||||||
# self.model.resize_token_embeddings(new_vocab_size)
|
world_size = self.shard_config.tensor_parallel_size
|
||||||
|
|
||||||
|
if vocab_size % world_size != 0:
|
||||||
|
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||||
|
self.model.resize_token_embeddings(new_vocab_size)
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
@ -42,74 +43,62 @@ class MixtralPolicy(Policy):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
raise NotImplementedError
|
# tensor parallelism for non-moe params
|
||||||
# assert (
|
assert (
|
||||||
# self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||||
# ), f"The number of attention heads must be divisible by tensor parallel size."
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
||||||
# assert (
|
assert (
|
||||||
# self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
||||||
# ), f"The number of key_value heads must be divisible by tensor parallel size."
|
), f"The number of key_value heads must be divisible by tensor parallel size."
|
||||||
# decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
# "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
# "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||||
# "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
|
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
|
||||||
# // self.shard_config.tensor_parallel_size,
|
// self.shard_config.tensor_parallel_size,
|
||||||
# }
|
}
|
||||||
|
|
||||||
# policy[MixtralDecoderLayer] = ModulePolicyDescription(
|
policy[MixtralDecoderLayer] = ModulePolicyDescription(
|
||||||
# attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
# sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
# SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
# suffix="self_attn.q_proj",
|
suffix="self_attn.q_proj",
|
||||||
# target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
# kwargs={
|
),
|
||||||
# 'process_group': self.shard_config.tensor_parallel_process_group,
|
SubModuleReplacementDescription(
|
||||||
# }
|
suffix="self_attn.k_proj",
|
||||||
# ),
|
target_module=Linear1D_Col,
|
||||||
# SubModuleReplacementDescription(
|
),
|
||||||
# suffix="self_attn.k_proj",
|
SubModuleReplacementDescription(
|
||||||
# target_module=Linear1D_Col,
|
suffix="self_attn.v_proj",
|
||||||
# kwargs={
|
target_module=Linear1D_Col,
|
||||||
# 'process_group': self.shard_config.tensor_parallel_process_group,
|
),
|
||||||
# }
|
SubModuleReplacementDescription(
|
||||||
# ),
|
suffix="self_attn.o_proj",
|
||||||
# SubModuleReplacementDescription(
|
target_module=Linear1D_Row,
|
||||||
# suffix="self_attn.v_proj",
|
),
|
||||||
# target_module=Linear1D_Col,
|
# SubModuleReplacementDescription( # TODO: enable moe tp parallel
|
||||||
# kwargs={
|
# suffix="mlp.gate_proj",
|
||||||
# 'process_group': self.shard_config.tensor_parallel_process_group,
|
# target_module=Linear1D_Col,
|
||||||
# }
|
# ),
|
||||||
# ),
|
# SubModuleReplacementDescription(
|
||||||
# SubModuleReplacementDescription(
|
# suffix="mlp.up_proj",
|
||||||
# suffix="self_attn.o_proj",
|
# target_module=Linear1D_Col,
|
||||||
# target_module=Linear1D_Row,
|
# ),
|
||||||
# kwargs={
|
# SubModuleReplacementDescription(
|
||||||
# 'process_group': self.shard_config.tensor_parallel_process_group,
|
# suffix="mlp.down_proj",
|
||||||
# }
|
# target_module=Linear1D_Row,
|
||||||
# ),
|
# ),
|
||||||
# # SubModuleReplacementDescription(
|
],
|
||||||
# # suffix="mlp.gate_proj",
|
)
|
||||||
# # target_module=Linear1D_Col,
|
|
||||||
# # ),
|
|
||||||
# # SubModuleReplacementDescription(
|
|
||||||
# # suffix="mlp.up_proj",
|
|
||||||
# # target_module=Linear1D_Col,
|
|
||||||
# # ),
|
|
||||||
# # SubModuleReplacementDescription(
|
|
||||||
# # suffix="mlp.down_proj",
|
|
||||||
# # target_module=Linear1D_Row,
|
|
||||||
# # ),
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
|
|
||||||
if getattr(self.shard_config, "ep_group", None) is None:
|
if self.shard_config.ep_group:
|
||||||
# expert parallel
|
# expert parallel
|
||||||
self.append_or_create_submodule_replacement(
|
self.append_or_create_submodule_replacement(
|
||||||
description=[
|
description=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="block_sparse_moe",
|
suffix="block_sparse_moe",
|
||||||
target_module=EPMixtralSparseMoeBlock,
|
target_module=EPMixtralSparseMoeBlock,
|
||||||
kwargs={"ep_group": self.shard_config.ep_group},
|
kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
policy=policy,
|
policy=policy,
|
||||||
|
|
|
@ -47,6 +47,8 @@ class ShardConfig:
|
||||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
|
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
|
||||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||||
ep_group: Optional[ProcessGroup] = None
|
ep_group: Optional[ProcessGroup] = None
|
||||||
|
moe_tp_group: Optional[ProcessGroup] = None
|
||||||
|
|
||||||
# pipeline_parallel_size: int
|
# pipeline_parallel_size: int
|
||||||
# data_parallel_size: int
|
# data_parallel_size: int
|
||||||
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
||||||
|
|
|
@ -114,39 +114,64 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
"test_config",
|
"test_config",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"tp_size": 1,
|
"tp_size": 2,
|
||||||
"pp_size": 2,
|
"pp_size": 1,
|
||||||
"num_microbatches": 2,
|
|
||||||
"ep_size": 1,
|
"ep_size": 1,
|
||||||
"zero_stage": 0,
|
"zero_stage": 2,
|
||||||
"precision": "fp32",
|
"precision": "fp32",
|
||||||
}, # pp + ep
|
}, # [dp(2) + tp(2)] + [moe_dp(4)]
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 1,
|
||||||
|
"ep_size": 2,
|
||||||
|
"zero_stage": 2,
|
||||||
|
"precision": "fp32",
|
||||||
|
}, # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)]
|
||||||
{
|
{
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 2,
|
"pp_size": 2,
|
||||||
"num_microbatches": 2,
|
"num_microbatches": 2,
|
||||||
"ep_size": 1,
|
"ep_size": 1,
|
||||||
"zero_stage": 0,
|
"zero_stage": 2,
|
||||||
"precision": "fp32",
|
"precision": "fp32",
|
||||||
}, # pp + ep
|
}, # [dp(2) + pp(2)] + [moe_dp(4)]
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 2,
|
||||||
|
"num_microbatches": 2,
|
||||||
|
"ep_size": 1,
|
||||||
|
"zero_stage": 2,
|
||||||
|
"precision": "fp32",
|
||||||
|
}, # [dp(2) + pp(2)] + [moe_dp(4)]
|
||||||
{
|
{
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 2,
|
"pp_size": 2,
|
||||||
"num_microbatches": 2,
|
"num_microbatches": 2,
|
||||||
"ep_size": 4,
|
"ep_size": 4,
|
||||||
"zero_stage": 0,
|
"zero_stage": 2,
|
||||||
"precision": "fp32",
|
"precision": "fp32",
|
||||||
}, # pp + ep
|
}, # [dp(2) + pp(2)] + [ep(4))]
|
||||||
{"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "bf16"}, # full dp for moe and non-moe
|
{
|
||||||
{ # moe_dp = 2, non_moe_dp = 4
|
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 1,
|
"pp_size": 1,
|
||||||
"ep_size": 2,
|
"ep_size": 2,
|
||||||
"zero_stage": 1,
|
"zero_stage": 2,
|
||||||
"precision": "fp32",
|
"precision": "fp32",
|
||||||
}, # moe_dp = 1, non_moe_dp = 4
|
}, # [dp(4)] + [ep(2) + moe_tp(2)]
|
||||||
{"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp32"}, # full dp for non-moe and full ep for moe
|
{
|
||||||
{"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
|
"tp_size": 1,
|
||||||
|
"pp_size": 1,
|
||||||
|
"ep_size": 4,
|
||||||
|
"zero_stage": 2,
|
||||||
|
"precision": "fp32"
|
||||||
|
}, # full dp for non-moe and full ep for moe
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 1,
|
||||||
|
"ep_size": 1,
|
||||||
|
"zero_stage": 2,
|
||||||
|
"precision": "fp32"
|
||||||
|
}, # full dp for moe and non-moe
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_mixtral_test(test_config):
|
def run_mixtral_test(test_config):
|
||||||
|
|
Loading…
Reference in New Issue