mirror of https://github.com/hpcaitech/ColossalAI
[chore] arg pass & remove drop token
parent 8dbb86899d
commit 102b784a10
@@ -1,8 +1,8 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
-import numpy as np
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -21,7 +21,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -89,11 +88,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
             # dp_process_group=dp_process_group,
             tp_process_group=tp_process_group,
             pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
             ## moe args
             pg_to_param_list=pg_param_list,
         )
@@ -104,6 +101,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     """
 
     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
+        if "overlap_communication" not in kwargs:
+            kwargs["overlap_communication"] = False
+
         super().__init__(*args, **kwargs)
 
         self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@@ -142,7 +142,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.moe_dp_group = None
         self.ep_group = None
         self.moe_tp_group = None
 
         # create submesh for ep, moe_dp, moe_tp
         ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
             [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
@@ -182,7 +182,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 assert self.moe_tp_group is None
                 self.moe_tp_group = group
 
-        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
+        self.logger.info(
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            ranks=[0],
+        )
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(

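For context, a minimal usage sketch of the constructor touched above. Only ep_size, moe_tp_size and force_overlap_comm come from the signature in this diff; every other keyword argument is forwarded to HybridParallelPlugin through the super().__init__(*args, **kwargs) call shown here, and the concrete size values below are illustrative assumptions only.

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# Assumes the distributed environment has already been initialized
# (e.g. colossalai.launch_from_torch under a torchrun launch).
plugin = MoeHybridParallelPlugin(
    ep_size=2,                    # expert parallel size (required by the signature above)
    moe_tp_size=1,                # tensor parallel size for expert weights
    force_overlap_comm=False,
    tp_size=1,                    # the rest is forwarded to HybridParallelPlugin
    pp_size=1,
    zero_stage=1,
    overlap_communication=False,  # defaulted to False above when not given
)
booster = Booster(plugin=plugin)
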
@@ -14,7 +14,13 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import (
+    DPGradScalerIn,
+    DPGradScalerOut,
+    EPGradScalerIn,
+    EPGradScalerOut,
+    all_to_all_uneven,
+)
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -25,7 +31,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+    def setup_process_groups(
+        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -59,7 +67,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
+        module: MixtralSparseMoeBlock,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        moe_tp_group: ProcessGroup,
+        *args,
+        **kwargs,
     ) -> "EPMixtralSparseMoeBlock":
         # TODO: better init
         LazyInitContext.materialize(module)
@@ -96,8 +110,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        if self.tp_group.size() > 1:
-            dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+        # TODO drop tokens to reduce tp group redundant communication
 
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output

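The dispatch above hands the per-rank split lists to all_to_all_uneven. As a rough illustration of the data movement only (the real all_to_all_uneven in colossalai.moe._operation is autograd-aware, which this sketch is not), the same exchange can be written with the stock torch.distributed primitive; the helper name below and the assumption that tokens are already grouped by destination rank are mine.

import torch
import torch.distributed as dist

def all_to_all_uneven_sketch(inputs, input_split_list, output_split_list, group=None):
    # inputs: (num_local_tokens, hidden) with rows already grouped by destination
    # rank, so the first input_split_list[0] rows go to rank 0, the next
    # input_split_list[1] rows to rank 1, and so on.
    outputs = inputs.new_empty((sum(output_split_list), *inputs.shape[1:]))
    dist.all_to_all_single(
        outputs,
        inputs,
        output_split_sizes=output_split_list,
        input_split_sizes=input_split_list,
        group=group,
    )
    return outputs
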
@@ -116,20 +129,21 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         for i, split_states in enumerate(output_states_splits):
             if split_states.size(0) == 0:
                 continue
-            split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+            split_states = DPGradScalerIn.apply(
+                split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+            )
             expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
             split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
             split_states = expert.w2(split_states)
-            split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+            split_states = DPGradScalerOut.apply(
+                split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+            )
             output_states_list.append(split_states)
         output_states = torch.cat(output_states_list)
 
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
 
-        if self.tp_group.size() > 1:
-            dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
-
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
             selected_experts_idx.size(0), device=selected_experts_idx.device

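The last two context lines build the inverse permutation of selected_experts_idx: scattering arange() through the dispatch order is what lets the block restore the original token order after the experts have run. A tiny self-contained illustration with made-up values:

import torch

selected_experts_idx = torch.tensor([2, 0, 3, 1])  # some dispatch permutation
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(selected_experts_idx.size(0))

tokens = torch.tensor([10.0, 11.0, 12.0, 13.0])  # original order
dispatched = tokens[selected_experts_idx]        # order the experts see
restored = dispatched[recover_experts_idx]       # original order again
assert torch.equal(restored, tokens)
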
@@ -36,7 +36,13 @@ def check_mixtral_moe_layer():
     x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
     orig_output, orig_logits = orig_model(x)
     model = deepcopy(orig_model)
-    model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
+    model = EPMixtralSparseMoeBlock.from_native_module(
+        model,
+        ep_group=plugin.ep_group,
+        tp_group=plugin.tp_group,
+        moe_dp_group=plugin.moe_dp_group,
+        moe_tp_group=plugin.moe_tp_group,
+    )
     ep_output, ep_logits = model(x)
     assert_close(orig_logits, ep_logits)
     assert_close(orig_output, ep_output)

|
@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
|
|||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.checkpoint_io import MoECheckpointIO
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing import parameterize, spawn
|
||||
from colossalai.testing.utils import spawn
|
||||
|
@@ -115,7 +114,6 @@ def check_moe_checkpoint(test_config):
         pp_size=2,
         ep_size=2,
         tp_size=1,
-        checkpoint_io=MoECheckpointIO,
         microbatch_size=1,
         zero_stage=1,
     )

@@ -34,7 +34,7 @@ def split_grad(grad, world_size):
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
 @parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
     dtype = torch.bfloat16
 
     rank = torch.distributed.get_rank()

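With @parameterize("tp_size", [1, 2, 4]) always supplying the argument, a default on the signature is redundant, which is presumably why it is dropped here. Below is a hedged illustration of the sweep behaviour assumed of colossalai.testing.parameterize; this toy decorator is not its actual implementation.

import functools

def parameterize_sketch(name, values):
    # Each decorator sweeps one keyword argument; stacking several runs the
    # wrapped function over the cross product of all value lists.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})
        return wrapper
    return decorator

@parameterize_sketch("stage", [1])
@parameterize_sketch("ep_size", [1, 2, 4])
@parameterize_sketch("tp_size", [1, 2, 4])
def show_combination(stage: int, ep_size: int, tp_size: int):
    print(f"stage={stage} ep_size={ep_size} tp_size={tp_size}")

show_combination()  # runs 1 * 3 * 3 = 9 combinations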