mirror of https://github.com/hpcaitech/ColossalAI
[chore] arg pass & remove drop token
parent 8dbb86899d
commit 102b784a10
@@ -1,8 +1,8 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
-import numpy as np
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -21,7 +21,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
-from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
@@ -89,11 +88,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
-            # dp_process_group=dp_process_group,
             tp_process_group=tp_process_group,
             pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
-            ## moe args
             pg_to_param_list=pg_param_list,
         )
 
@@ -104,6 +101,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     """
 
     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
+        if "overlap_communication" not in kwargs:
+            kwargs["overlap_communication"] = False
+
         super().__init__(*args, **kwargs)
 
         self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
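Note on the added guard: it turns `overlap_communication` off unless the caller passes it explicitly, presumably because the parent `HybridParallelPlugin` defaults it to True and the `force_overlap_comm` flag in the signature suggests overlapping is only meant to be enabled deliberately for MoE. A self-contained restatement of the kwargs-defaulting idiom (the helper function is illustrative only, not part of the commit):

    def apply_moe_default(kwargs: dict) -> dict:
        # Inject the conservative default only when the caller stayed silent.
        if "overlap_communication" not in kwargs:
            kwargs["overlap_communication"] = False
        return kwargs

    assert apply_moe_default({})["overlap_communication"] is False
    assert apply_moe_default({"overlap_communication": True})["overlap_communication"] is True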
@@ -182,7 +182,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         assert self.moe_tp_group is None
         self.moe_tp_group = group
 
-        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
+        self.logger.info(
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            ranks=[0],
+        )
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
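The reformatted log call above uses `dist.get_process_group_ranks` to report which global ranks each MoE group spans. A minimal single-process sketch of that API (requires torch >= 2.0; the address, port, and one-rank group layout are illustrative stand-ins for the plugin's real moe_dp/ep/moe_tp groups):

    import os

    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    group = dist.new_group(ranks=[0])  # stand-in for moe_dp_group / ep_group / moe_tp_group
    # get_process_group_ranks returns the global ranks belonging to `group`
    print(f"rank {dist.get_rank()} group ranks {dist.get_process_group_ranks(group)}")
    dist.destroy_process_group()

The remaining hunks touch the Mixtral expert-parallel modeling code (class EPMixtralSparseMoeBlock) and the tests.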
@@ -14,7 +14,13 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import (
+    DPGradScalerIn,
+    DPGradScalerOut,
+    EPGradScalerIn,
+    EPGradScalerOut,
+    all_to_all_uneven,
+)
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -25,7 +31,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+    def setup_process_groups(
+        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -59,7 +67,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
+        module: MixtralSparseMoeBlock,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        moe_tp_group: ProcessGroup,
+        *args,
+        **kwargs,
     ) -> "EPMixtralSparseMoeBlock":
         # TODO: better init
         LazyInitContext.materialize(module)
@@ -96,8 +110,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        if self.tp_group.size() > 1:
-            dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+        # TODO drop tokens to reduce tp group redundant communication
 
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
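With `drop_tokens`/`gather_tokens` removed (the "remove drop token" half of the commit message), routed tokens go straight into the uneven all-to-all. The split lists it consumes come from the two lines kept above: per-expert token counts folded into per-EP-rank counts. A self-contained sketch of that computation with illustrative sizes (`all_to_all_uneven` itself is ColossalAI-internal and not reproduced here):

    import torch

    # Mirrors the split-list computation in the hunk above: entry j of the
    # resulting list is how many tokens this rank exchanges with EP rank j.
    ep_size, num_experts_per_ep = 2, 4                           # illustrative sizes
    input_split_sizes = torch.tensor([3, 0, 5, 1, 2, 2, 0, 4])   # tokens routed per expert
    input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
    assert input_split_list == [9, 8]                            # tokens per EP rank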
@@ -116,20 +129,21 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         for i, split_states in enumerate(output_states_splits):
             if split_states.size(0) == 0:
                 continue
-            split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+            split_states = DPGradScalerIn.apply(
+                split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+            )
             expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
             split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
             split_states = expert.w2(split_states)
-            split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+            split_states = DPGradScalerOut.apply(
+                split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+            )
             output_states_list.append(split_states)
         output_states = torch.cat(output_states_list)
 
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
 
-        if self.tp_group.size() > 1:
-            dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
-
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
             selected_experts_idx.size(0), device=selected_experts_idx.device
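The `DPGradScalerIn`/`DPGradScalerOut` (and EP) helpers come from `colossalai.moe._operation`; they presumably pass activations through unchanged and rescale only the backward gradient, compensating for how many data-parallel replicas activated each expert. A standalone sketch of that identity-forward, scaled-backward pattern (this re-implementation is illustrative, not the library's code):

    import torch

    class GradScale(torch.autograd.Function):
        """Identity in forward; multiplies the incoming gradient by `scale`."""

        @staticmethod
        def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
            ctx.scale = scale
            return x.view_as(x)  # identity that still registers in autograd

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * ctx.scale, None  # no gradient w.r.t. `scale`

    x = torch.ones(3, requires_grad=True)
    GradScale.apply(x, 0.5).sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 0.5))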
@@ -36,7 +36,13 @@ def check_mixtral_moe_layer():
     x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
     orig_output, orig_logits = orig_model(x)
     model = deepcopy(orig_model)
-    model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
+    model = EPMixtralSparseMoeBlock.from_native_module(
+        model,
+        ep_group=plugin.ep_group,
+        tp_group=plugin.tp_group,
+        moe_dp_group=plugin.moe_dp_group,
+        moe_tp_group=plugin.moe_tp_group,
+    )
     ep_output, ep_logits = model(x)
     assert_close(orig_logits, ep_logits)
     assert_close(orig_output, ep_output)
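The equivalence checks rely on `torch.testing.assert_close`, which compares within dtype-dependent tolerances rather than demanding bit-exact equality; that matters here because expert-parallel dispatch reorders floating-point reductions. A self-contained reminder of its behavior:

    import torch
    from torch.testing import assert_close

    a = torch.tensor([1.0, 2.0])
    b = a + 1e-7      # tiny drift, e.g. from a different reduction order
    assert_close(a, b)  # passes: within the default float32 rtol/atol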
@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import parameterize, spawn
 from colossalai.testing.utils import spawn
@@ -115,7 +114,6 @@ def check_moe_checkpoint(test_config):
         pp_size=2,
         ep_size=2,
         tp_size=1,
-        checkpoint_io=MoECheckpointIO,
         microbatch_size=1,
         zero_stage=1,
     )
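Dropping `checkpoint_io=MoECheckpointIO` lines up with the plugin's `get_checkpoint_io` (reformatted earlier in this commit) returning a `MoECheckpointIO` itself, so the test no longer wires it in by hand. A hedged sketch of the resulting call site; the argument values mirror the test, and the distributed setup is assumed:

    # Assumes colossalai.launch has already initialized the distributed environment.
    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    plugin = MoeHybridParallelPlugin(
        pp_size=2,
        ep_size=2,
        tp_size=1,
        microbatch_size=1,
        zero_stage=1,
    )
    booster = Booster(plugin=plugin)  # the booster obtains MoECheckpointIO from the plugin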
@@ -34,7 +34,7 @@ def split_grad(grad, world_size):
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
 @parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
     dtype = torch.bfloat16
 
     rank = torch.distributed.get_rank()
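The `tp_size` default can be dropped because `@parameterize("tp_size", [1, 2, 4])` always supplies the argument, so the default was dead code. A minimal illustrative re-implementation of that expansion behavior (the real decorator lives in `colossalai.testing` and may differ in detail):

    from functools import wraps

    def parameterize(name, values):
        def decorator(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                for value in values:  # invoke once per parameter value
                    fn(*args, **{**kwargs, name: value})
            return wrapper
        return decorator

    @parameterize("tp_size", [1, 2, 4])
    def run(tp_size: int):
        print("tp_size =", tp_size)

    run()  # prints tp_size = 1, 2, 4 -- the argument is always provided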