[chore] arg pass & remove drop token

colossalchat
hxwang 2024-07-12 09:08:16 +00:00 committed by Hongxin Liu
parent 8dbb86899d
commit 102b784a10
5 changed files with 41 additions and 20 deletions

View File

@@ -1,8 +1,8 @@
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
import numpy as np
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -21,7 +21,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -89,11 +88,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
# dp_process_group=dp_process_group,
tp_process_group=tp_process_group,
pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
## moe args
pg_to_param_list=pg_param_list,
)
@@ -104,6 +101,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
kwargs["overlap_communication"] = False
super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
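A minimal, self-contained sketch of the keyword handling added in __init__ above. DemoBase and DemoMoePlugin are hypothetical stand-ins, not the real plugins; the base default of True mirrors HybridParallelPlugin's overlap_communication default:

class DemoBase:
    def __init__(self, overlap_communication: bool = True, **kwargs):
        self.overlap_communication = overlap_communication

class DemoMoePlugin(DemoBase):
    def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs):
        # only force overlap_communication off when the caller did not set it
        if "overlap_communication" not in kwargs:
            kwargs["overlap_communication"] = False
        super().__init__(*args, **kwargs)

assert DemoMoePlugin(ep_size=2).overlap_communication is False
assert DemoMoePlugin(ep_size=2, overlap_communication=True).overlap_communication is True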
@@ -182,7 +182,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
assert self.moe_tp_group is None
self.moe_tp_group = group
self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
self.logger.info(
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
ranks=[0],
)
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(

View File

@@ -14,7 +14,13 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -25,7 +31,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
def setup_process_groups(
self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -59,7 +67,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
@staticmethod
def from_native_module(
module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
module: MixtralSparseMoeBlock,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPMixtralSparseMoeBlock":
# TODO: better init
LazyInitContext.materialize(module)
@@ -96,8 +110,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
if self.tp_group.size() > 1:
dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
# TODO drop tokens to reduce tp group redundant communication
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
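For context on the TODO left above: per the removed comment, the drop_tokens/gather_tokens pair was meant to cut redundant TP-group communication by having each TP rank exchange only a slice of dispatch_states. A rough single-process sketch of that idea, assuming the pair shards and re-gathers along the given dim across the TP group; plain torch.chunk/cat stands in for the collective ops:

import torch

tp_size = 2
dispatch_states = torch.randn(6, 8)  # 6 dispatched tokens, hidden size 8

# "drop": each TP rank would keep only its slice of dim -1 before the all-to-all
shards = list(torch.chunk(dispatch_states, tp_size, dim=-1))

# ... the expert all-to-all and computation would run on the smaller slices here ...

# "gather": the full hidden dim is reassembled afterwards
restored = torch.cat(shards, dim=-1)
assert torch.equal(restored, dispatch_states)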
@@ -116,20 +129,21 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
split_states = DPGradScalerIn.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
)
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
split_states = DPGradScalerOut.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
if self.tp_group.size() > 1:
dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
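The recover_experts_idx assignment above builds the inverse permutation of selected_experts_idx, so expert outputs can be put back into the original token order. A standalone illustration with dummy tokens (values are hypothetical):

import torch

selected_experts_idx = torch.tensor([2, 0, 3, 1])  # hypothetical dispatch order
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(selected_experts_idx.size(0))

tokens = torch.arange(4.0).unsqueeze(-1)           # 4 dummy token embeddings
dispatched = tokens[selected_experts_idx]          # permuted for expert dispatch
assert torch.equal(dispatched[recover_experts_idx], tokens)  # original order restored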

View File

@@ -36,7 +36,13 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
model = EPMixtralSparseMoeBlock.from_native_module(
model,
ep_group=plugin.ep_group,
tp_group=plugin.tp_group,
moe_dp_group=plugin.moe_dp_group,
moe_tp_group=plugin.moe_tp_group,
)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)

View File

@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
@@ -115,7 +114,6 @@ def check_moe_checkpoint(test_config):
pp_size=2,
ep_size=2,
tp_size=1,
checkpoint_io=MoECheckpointIO,
microbatch_size=1,
zero_stage=1,
)
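With the explicit checkpoint_io argument gone from the test, the booster obtains the IO class from the plugin itself through get_checkpoint_io() (see the first file above). A usage sketch under that assumption; it presumes a multi-process launch via colossalai.launch, so it is illustrative rather than directly runnable here:

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    pp_size=2,
    ep_size=2,
    tp_size=1,
    microbatch_size=1,
    zero_stage=1,
)
booster = Booster(plugin=plugin)  # checkpoint IO now comes from plugin.get_checkpoint_io()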

View File

@@ -34,7 +34,7 @@ def split_grad(grad, world_size):
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@parameterize("tp_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
dtype = torch.bfloat16
rank = torch.distributed.get_rank()
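Since @parameterize supplies a tp_size value on every run, the default in the test signature was redundant and is dropped. A simplified, hypothetical stand-in for the decorator (not ColossalAI's implementation) showing why:

def parameterize_sketch(name, values):
    def decorator(fn):
        def wrapper(**kwargs):
            for value in values:
                fn(**{**kwargs, name: value})
        return wrapper
    return decorator

@parameterize_sketch("ep_size", [1, 2])
@parameterize_sketch("tp_size", [1, 2])
def run_sketch(ep_size: int, tp_size: int):
    print(f"ep_size={ep_size} tp_size={tp_size}")

run_sketch()  # every combination receives both arguments explicitly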