mirror of https://github.com/hpcaitech/ColossalAI
[moe] init moe plugin comm setting with sp
parent 09d6280d3e
commit 877d94bb8c

@@ -1,6 +1,5 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
kwargs["overlap_communication"] = False
kwargs["overlap_communication"] = False  # defaults to True in the super class

super().__init__(*args, **kwargs)

self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size

self._init_moe_param_comm()

self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)

if self.use_ddp:
warnings.warn(
"Will have to check that all params are used in PyTorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True

world_size = dist.get_world_size()
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"If DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grads across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}; you might want to set ep_size=1 or zero_stage > 0"
)

if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} does not equal pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size} * sp_size={self.sp_size}"
)

# self._init_moe_param_comm()

self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])

# set ep_group after super init
# set ep_group after super().__init__()
# TODO: do this in a better way
self.moe_dp_group = self.pp_group
self.ep_group = self.pp_group
self.moe_tp_group = self.pp_group

self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
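The constructor above fixes every parallel axis except MoE data parallelism, which is whatever factor of the world size is left over, and then re-checks the product. A minimal standalone sketch of that arithmetic (the helper name is made up for illustration and is not part of the plugin):

def derive_moe_dp_size(world_size: int, pp_size: int, ep_size: int, moe_tp_size: int, sp_size: int) -> int:
    # ranks left over once pipeline, expert, MoE-tensor and sequence parallelism are fixed
    denom = pp_size * ep_size * moe_tp_size * sp_size
    if world_size % denom != 0:
        raise ValueError(f"world_size={world_size} is not divisible by pp*ep*moe_tp*sp={denom}")
    moe_dp_size = world_size // denom
    # same product check as the constructor above
    assert pp_size * moe_dp_size * ep_size * moe_tp_size * sp_size == world_size
    return moe_dp_size

assert derive_moe_dp_size(16, 2, 2, 1, 2) == 2  # 16 GPUs: pp=2, ep=2, moe_tp=1, sp=2 leaves moe_dp=2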
@@ -144,48 +141,77 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

self.force_overlap_comm = force_overlap_comm

def _init_moe_param_comm(self):
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
world_size = dist.get_world_size()

# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)
if self.enable_sequence_parallelism:
# if sequence parallelism is enabled, we reuse the same group for ep and sp
if self.sequence_parallelism_mode == "all_to_all":
# when sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
)

global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
self.moe_dp_size = self.dp_size
self.moe_dp_group = self.dp_group  # NOTE: the order of these assignments matters
self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.ep_group = self.sp_group
self.moe_tp_group = self.tp_group
else:
raise NotImplementedError(
f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
)

# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
else:
self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)

# hardcoded here since we only have 3 axes
# moe_dp_group
for ep_idx in range(self.ep_size):
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} does not equal pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size} * sp_size={self.sp_size}"
)

self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None

# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)

global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)

# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)

# hardcoded here since we only have 3 axes
# moe_dp_group
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group

if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
# NOTE: different tp settings between MoE and non-MoE params require complex comm logic, where all_to_all might not be suitable
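The nested loops above carve each pipeline stage's ranks into moe_dp, ep, and moe_tp groups by slicing a 3-D submesh. A standalone NumPy sketch of the same slicing with toy sizes (no process groups are created here):

import numpy as np

moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
stage_ranks = np.arange(moe_dp_size * ep_size * moe_tp_size)  # e.g. ranks 0..7 of one pp stage
submesh = stage_ranks.reshape(moe_dp_size, ep_size, moe_tp_size)  # axis 0 moe_dp, axis 1 ep, axis 2 moe_tp

moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
moe_tp_groups = [submesh[d, e, :].tolist() for d in range(moe_dp_size) for e in range(ep_size)]

print(moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(ep_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Each rank lands in exactly one group per axis, which is why the assert-is-None guards in the loops above are expected never to trip.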
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)

self.logger.info(
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
ranks=[0],
)
@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
param_info = get_param_info(optimizer)

# TODO: Support Galore + ZeRO
self.zero_stage
deepcopy(self.zero_config)
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
use_ddp=self.use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
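Under "all_to_all" (Ulysses-style) sequence parallelism every rank holds a different slice of the batch, so gradient averaging has to span the flattened dp and sp axes rather than dp alone; that is what create_group_along_axis([self.dp_axis, self.sp_axis]) provides as the dp_group here. A toy sketch of the merged layout (assumed 2x2 mesh, NumPy only, no process groups):

import numpy as np

dp_size, sp_size = 2, 2
mesh = np.arange(dp_size * sp_size).reshape(dp_size, sp_size)  # ranks laid out as [dp, sp]

sp_groups = [mesh[d, :].tolist() for d in range(dp_size)]  # [[0, 1], [2, 3]] shard the sequence
merged_dp_group = mesh.flatten().tolist()                  # [0, 1, 2, 3] average gradients together
print(sp_groups, merged_dp_group)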
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
tp_process_group=self.tp_group,
)
else:
if not (self.dp_size > 1 or self.moe_dp_size > 1):
if self.dp_size <= 1:
warnings.warn(
"Using the ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider setting zero_stage=0."
@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
dist.get_rank()

output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
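The dispatch above sorts the routed tokens by expert id and counts how many are bound for each expert; those counts are what ranks exchange via all_to_all_single so each one learns how much it will receive. A single-process sketch of the bookkeeping with toy sizes (no process group, plain PyTorch):

import torch

num_experts, top_k = 4, 1
selected_experts = torch.tensor([2, 0, 2, 1, 3, 0])       # one routed expert id per token
hidden_states = torch.randn(selected_experts.numel(), 8)  # toy hidden size of 8

selected_experts_idx = selected_experts.argsort()          # order tokens by expert id
dispatch_states = hidden_states.repeat(top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=num_experts)

print(input_split_sizes.tolist())  # [2, 1, 2, 1] tokens bound for experts 0..3
print(dispatch_states.shape)       # torch.Size([6, 8])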
@@ -23,7 +23,7 @@ NUM_HEADS = 4
TOP_K = 1


@parameterize("config", [(1, 1, 1)])
@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float16
@@ -24,11 +24,10 @@ NUM_HEADS = 4
TOP_K = 1


@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float32

dtype, precision = torch.float16, "fp16"
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
precision="fp32",
precision=precision,
)
booster = Booster(plugin=plugin)
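For reference, a sketch of how one (stage, ep_size, tp_size) tuple from the parameterized configs maps onto the plugin arguments visible above, now exercised in fp16 instead of fp32; argument names not shown in this hunk (tp_size, pp_size, ep_size) are assumptions:

stage, ep_size, tp_size = 1, 2, 2
plugin_kwargs = dict(
    tp_size=tp_size,
    pp_size=1,  # assumption: these tests run without pipeline parallelism
    ep_size=ep_size,
    zero_stage=stage,
    overlap_communication=False,
    initial_scale=1,
    precision="fp16",
)
print(plugin_kwargs)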
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):

dist.barrier()

saved_model = MixtralModel.from_pretrained(model_dir).cuda()
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(torch_model, saved_model)

dist.barrier()
@@ -26,9 +26,7 @@ top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
if loose_close(p1, p2, p1.dtype):
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")
loose_close(p1, p2, p1.dtype)


def get_optimizer_snapshot(optim):
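check_model_equal now delegates the tolerance handling to loose_close, which raises on mismatch instead of returning a flag. A minimal stand-in with assumed per-dtype tolerances (not the repo's actual implementation):

import torch

def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    # looser tolerances for half-precision dtypes, tight ones otherwise (assumed values)
    rtol, atol = {
        torch.float16: (5e-2, 5e-4),
        torch.bfloat16: (4e-3, 4e-3),
    }.get(dtype, (1e-5, 1e-8))
    a = a.detach().to(dtype).float()
    b = b.detach().to(dtype).float()
    if not torch.allclose(a, b, rtol=rtol, atol=atol):
        raise AssertionError(f"tensors differ beyond rtol={rtol}, atol={atol}")

loose_close_sketch(torch.ones(4), torch.ones(4) + 1e-6, torch.float16)  # passes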
@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
# {
#     "tp_size": 1,
#     "pp_size": 2,
#     "pp_size": 1,
#     "num_microbatches": 2,
#     "ep_size": 2,
#     "zero_stage": 1,
#     "zero_stage": 0,
#     "overlap_communication": False,
#     "precision": "fp32",
#     "precision": "fp16",
# },  # [dp(4)] + [moe_dp(4)]
# {
#     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{  # Ulysses + Flash attention
"tp_size": 1,
"pp_size": 1,
"sp_size": 4,
"sp_size": 2,
"ep_size": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",