[moe] init moe plugin comm setting with sp

colossalchat
hxwang 2024-07-18 08:37:06 +00:00 committed by Hongxin Liu
parent 09d6280d3e
commit 877d94bb8c
7 changed files with 101 additions and 95 deletions

View File

@@ -1,6 +1,5 @@
 import warnings
 from collections import defaultdict
-from copy import deepcopy
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
         if "overlap_communication" not in kwargs:
-            kwargs["overlap_communication"] = False
+            kwargs["overlap_communication"] = False  # defaults to True in the superclass
         super().__init__(*args, **kwargs)
-        self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+        self.ep_size = ep_size
+        self.moe_tp_size = moe_tp_size
+        self._init_moe_param_comm()
+        self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+            self.dp_size == 1
+            and self.pp_size == 1
+            and self.enable_sequence_parallelism
+            and self.sequence_parallelism_mode == "all_to_all"
+        )
         if self.use_ddp:
             warnings.warn(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
             )
             self.ddp_config["find_unused_parameters"] = True
-        world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
-        self.ep_size = ep_size
-        self.moe_tp_size = moe_tp_size
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
-            raise ValueError(
-                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
-            )
-        # self._init_moe_param_comm()
-        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
+            if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                raise ValueError(
+                    f"if DDP is used, dp_group and moe_dp_group are expected to be the same, since DDP can only reduce grads across a single group; but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}. You might want to set ep_size=1 or zero_stage > 0"
+                )
 
-        # set ep_group after super init
+        # set ep_group after super().__init__()
         # TODO do it in a better way
-        self.moe_dp_group = self.pp_group
-        self.ep_group = self.pp_group
-        self.moe_tp_group = self.pp_group
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
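
The new `use_ddp` rule covers two cases: plain data parallelism without ZeRO, and a pure all_to_all sequence-parallel run with `dp_size == 1`. As a reading aid, here is a minimal standalone sketch of that rule; the helper name `should_use_ddp` is ours, not part of the commit.

# Standalone sketch of the DDP-eligibility rule above (hypothetical helper, not in the commit).
def should_use_ddp(dp_size, pp_size, zero_stage, enable_sp, sp_mode):
    plain_dp = dp_size > 1 and pp_size == 1 and zero_stage == 0
    sp_all_to_all = dp_size == 1 and pp_size == 1 and enable_sp and sp_mode == "all_to_all"
    return plain_dp or sp_all_to_all

# Plain data parallelism without ZeRO falls back to PyTorch DDP ...
assert should_use_ddp(dp_size=4, pp_size=1, zero_stage=0, enable_sp=False, sp_mode=None)
# ... and so does an all_to_all sequence-parallel run with dp_size == 1.
assert should_use_ddp(dp_size=1, pp_size=1, zero_stage=0, enable_sp=True, sp_mode="all_to_all")
# With ZeRO (stage > 0) or pipeline parallelism, DDP is not used.
assert not should_use_ddp(dp_size=4, pp_size=1, zero_stage=1, enable_sp=False, sp_mode=None)
assert not should_use_ddp(dp_size=2, pp_size=2, zero_stage=0, enable_sp=False, sp_mode=None)
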
@@ -144,6 +141,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.force_overlap_comm = force_overlap_comm
 
     def _init_moe_param_comm(self):
+        world_size = dist.get_world_size()
+        if self.enable_sequence_parallelism:
+            # if sequence parallelism is enabled, we reuse the same group for ep and sp
+            if self.sequence_parallelism_mode == "all_to_all":
+                # when sequence parallelism is enabled, ep_group reuses sp_group
+                if self.ep_size != self.sp_size:
+                    raise ValueError(
+                        f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    )
+
+                self.moe_dp_size = self.dp_size
+                self.moe_dp_group = self.dp_group  # NOTE: the order of these assignments matters
+                self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+                self.ep_group = self.sp_group
+                self.moe_tp_group = self.tp_group
+            else:
+                raise NotImplementedError(
+                    f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
+                )
+        else:
+            self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
+            if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+                raise ValueError(
+                    f"world_size={world_size} should equal pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size} * sp_size={self.sp_size}"
+                )
+
         self.moe_dp_group = None
         self.ep_group = None
         self.moe_tp_group = None
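
In the non-sequence-parallel branch above, `moe_dp_size` is derived from the world size and the full product (including `sp_size`) is then validated. A minimal pure-Python sketch of that arithmetic, with an assumed helper name `derive_moe_dp_size`:

# Sketch of the size arithmetic in the non-SP branch of _init_moe_param_comm (assumed helper, not from the commit).
def derive_moe_dp_size(world_size, pp_size, ep_size, moe_tp_size, sp_size):
    moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)
    if pp_size * moe_dp_size * ep_size * moe_tp_size * sp_size != world_size:
        raise ValueError(
            f"world_size={world_size} != pp_size={pp_size} * moe_dp_size={moe_dp_size} "
            f"* ep_size={ep_size} * moe_tp_size={moe_tp_size} * sp_size={sp_size}"
        )
    return moe_dp_size

# 8 ranks with pp=1, ep=2, moe_tp=1, sp=1 give 4-way MoE data parallelism.
assert derive_moe_dp_size(8, 1, 2, 1, 1) == 4
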
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )
 
         self.logger.info(
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )
@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         param_info = get_param_info(optimizer)
 
         # TODO: Support Galore + ZeRO
-        self.zero_stage
-        deepcopy(self.zero_config)
         # Replace with distributed implementation if exists
         optimizer = cast_to_distributed(optimizer)
 
         if not isinstance(model, ModelWrapper):
-            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-                self.dp_size == 1
-                and self.pp_size == 1
-                and self.enable_sequence_parallelism
-                and self.sequence_parallelism_mode == "all_to_all"
-            )
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            else:
-                dp_group = self.dp_group
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=dp_group,
+                dp_group=self.dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=use_ddp,
+                use_ddp=self.use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 tp_process_group=self.tp_group,
             )
         else:
-            if not (self.dp_size > 1 or self.moe_dp_size > 1):
+            if self.dp_size <= 1:
                 warnings.warn(
                     "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                     "If you do not intend to use cpu_offload, please consider set zero_stage=0."

View File

@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import (
-    DPGradScalerIn,
-    DPGradScalerOut,
-    EPGradScalerIn,
-    EPGradScalerOut,
-    all_to_all_uneven,
-)
+from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig

View File

@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         selected_experts_idx = selected_experts.argsort()
         dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
-        dist.get_rank()
 
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
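
For intuition, the dispatch bookkeeping above can be reproduced on a single process with plain tensors. This sketch (not from the commit) shows how per-expert token counts become the `all_to_all` input split sizes:

# Single-process sketch of the dispatch bookkeeping: tokens are reordered by their
# assigned expert, and per-expert counts become the all_to_all input split sizes.
import torch

num_experts, top_k, hidden = 4, 1, 8
hidden_states = torch.randn(6, hidden)               # 6 tokens
selected_experts = torch.tensor([2, 0, 3, 2, 1, 0])  # expert id per token (top_k=1)

selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=num_experts)

print(input_split_sizes.tolist())  # [2, 1, 2, 1]: tokens per expert, in expert order
# In the EP block, dist.all_to_all_single then exchanges these counts across the
# ep_group to obtain output_split_sizes before all_to_all_uneven moves the tokens.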

View File

@@ -23,7 +23,7 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
     dtype = torch.float16
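
Each tuple is `(zero_stage, ep_size, tp_size)`; the widened matrix now also covers runs without ZeRO. Assuming the test is spawned on 4 ranks (our assumption, not stated in this hunk), every combination has to divide the world size, which a quick sketch can check:

# Sketch (not from the commit): sanity-check the (stage, ep_size, tp_size) tuples above
# against an assumed 4-rank run, mirroring the world-size check in the plugin.
WORLD_SIZE = 4  # assumption: these tests are spawned on 4 ranks
configs = [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)]
for stage, ep_size, tp_size in configs:
    assert WORLD_SIZE % (ep_size * tp_size) == 0, (stage, ep_size, tp_size)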

View File

@@ -24,11 +24,10 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
-    dtype = torch.float32
+    dtype, precision = torch.float16, "fp16"
     rank = torch.distributed.get_rank()
     torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         overlap_communication=False,
         initial_scale=1,
-        precision="fp32",
+        precision=precision,
     )
     booster = Booster(plugin=plugin)
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()
 
-    saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
     check_model_equal(torch_model, saved_model)
     dist.barrier()

View File

@@ -26,9 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if loose_close(p1, p2, p1.dtype):
-            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
-            raise AssertionError(f"Model parameter {name} is not equal")
+        loose_close(p1, p2, p1.dtype)
 
 
 def get_optimizer_snapshot(optim):
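
`check_model_equal` now leaves the failure reporting to `loose_close`, which is expected to assert internally with dtype-aware tolerances. A hedged sketch of such a comparison (our own stand-in, not the test suite's actual `loose_close`; the tolerance values are illustrative):

# Dtype-aware "loose" comparison, similar in spirit to the loose_close helper used above.
import torch

def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype, name: str = "") -> None:
    # Wider tolerances for half-precision dtypes, tighter for fp32 (illustrative values).
    if dtype in (torch.float16, torch.bfloat16):
        rtol, atol = 5e-2, 5e-3
    else:
        rtol, atol = 1e-5, 1e-8
    a = a.detach().float()
    b = b.detach().float()
    assert torch.allclose(a, b, rtol=rtol, atol=atol), f"Tensor {name} mismatch: {a} vs {b}"

loose_close_sketch(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0001]), torch.float16)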

View File

@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # {
         #     "tp_size": 1,
-        #     "pp_size": 2,
+        #     "pp_size": 1,
         #     "num_microbatches": 2,
         #     "ep_size": 2,
-        #     "zero_stage": 1,
+        #     "zero_stage": 0,
         #     "overlap_communication": False,
-        #     "precision": "fp32",
+        #     "precision": "fp16",
         # },  # [dp(4)] + [moe_dp(4)]
         # {
         #     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {  # Ulysses + Flash attention
             "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "ep_size": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",