[chore] arg pass & remove drop token

colossalchat
hxwang 2024-07-12 09:08:16 +00:00 committed by Hongxin Liu
parent 8dbb86899d
commit 102b784a10
5 changed files with 41 additions and 20 deletions

View File

@ -1,8 +1,8 @@
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@ -21,7 +21,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@ -89,11 +88,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
# dp_process_group=dp_process_group,
tp_process_group=tp_process_group,
pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
## moe args
pg_to_param_list=pg_param_list,
)
@ -104,6 +101,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
""" """
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None: def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
kwargs["overlap_communication"] = False
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@ -182,7 +182,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
assert self.moe_tp_group is None
self.moe_tp_group = group
self.logger.info(
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
ranks=[0],
)

def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(

View File

@ -14,7 +14,13 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
@ -25,7 +31,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

def setup_process_groups(
self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@ -59,7 +67,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
@staticmethod
def from_native_module(
module: MixtralSparseMoeBlock,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPMixtralSparseMoeBlock":
# TODO: better init
LazyInitContext.materialize(module)
@ -96,8 +110,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
if self.tp_group.size() > 1: # TODO drop tokens to reduce tp group redundant communication
dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)

# compute expert output
@ -116,20 +129,21 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
split_states = DPGradScalerIn.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
)
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
split_states = DPGradScalerOut.apply(
split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
if self.tp_group.size() > 1:
dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device

View File

@ -36,7 +36,13 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(
model,
ep_group=plugin.ep_group,
tp_group=plugin.tp_group,
moe_dp_group=plugin.moe_dp_group,
moe_tp_group=plugin.moe_tp_group,
)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)

View File

@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
@ -115,7 +114,6 @@ def check_moe_checkpoint(test_config):
pp_size=2,
ep_size=2,
tp_size=1,
checkpoint_io=MoECheckpointIO,
microbatch_size=1,
zero_stage=1,
)

View File

@ -34,7 +34,7 @@ def split_grad(grad, world_size):
@parameterize("stage", [1]) @parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4]) @parameterize("ep_size", [1, 2, 4])
@parameterize("tp_size", [1, 2, 4]) @parameterize("tp_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1): def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
dtype = torch.bfloat16 dtype = torch.bfloat16
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()