mirror of https://github.com/hpcaitech/ColossalAI
[Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3e2b6132b7
commit 404b16faf3
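The change set below wires Ulysses-style (all_to_all) sequence parallelism into MoeHybridParallelPlugin and the Mixtral shardformer path. As a rough usage sketch (keyword arguments are taken from the "Ulysses + Flash attention" test config added at the bottom of this diff; the exact launch call and Booster wiring are assumptions about the surrounding API, not part of this commit):

    # Hedged usage sketch: enabling Ulysses (all_to_all) sequence parallelism for a MoE model.
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    colossalai.launch_from_torch()  # assumes the processes were started with torchrun

    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        sp_size=4,                              # sequence dimension is sharded over 4 ranks
        ep_size=1,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="all_to_all",
        zero_stage=0,
        precision="fp16",
        initial_scale=1,
        find_unused_parameters=True,
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)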
@@ -1,4 +1,6 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

@@ -22,6 +24,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.tensor.moe_tensor.api import is_moe_tensor

@@ -114,21 +118,25 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            self.ddp_config["find_unused_parameters"] = True

        world_size = dist.get_world_size()
        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
        self.ep_size = ep_size
        self.moe_tp_size = moe_tp_size

        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
            raise ValueError(
                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size} * sp_size={self.sp_size}"
            )

        self._init_moe_param_comm()
        # self._init_moe_param_comm()

        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])

        # set ep_group after super init
        # TODO do it in a better way
        self.moe_dp_group = self.pp_group
        self.ep_group = self.pp_group
        self.moe_tp_group = self.pp_group

        self.shard_config.ep_group = self.ep_group
        self.shard_config.moe_dp_group = self.moe_dp_group
        self.shard_config.moe_tp_group = self.moe_tp_group
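With sequence parallelism in the mix, the MoE data-parallel degree is now the remaining factor of the world size. A small worked example of the bookkeeping enforced above (the numbers are illustrative only):

    # Illustrative check of the factorization used by the plugin.
    world_size = 16
    pp_size, ep_size, moe_tp_size, sp_size = 2, 2, 1, 2
    moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size * sp_size)  # -> 4
    assert pp_size * moe_dp_size * ep_size * moe_tp_size * sp_size == world_size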
@@ -205,15 +213,32 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        param_info = get_param_info(optimizer)

        # TODO: Support Galore + ZeRO
        self.zero_stage
        deepcopy(self.zero_config)
        # Replace with distributed implementation if exists
        optimizer = cast_to_distributed(optimizer)

        if not isinstance(model, ModelWrapper):
            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
                self.dp_size == 1
                and self.pp_size == 1
                and self.enable_sequence_parallelism
                and self.sequence_parallelism_mode == "all_to_all"
            )
            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
            else:
                dp_group = self.dp_group
            model = HybridParallelModule(
                module=model,
                precision=self.precision,
                shard_config=self.shard_config,
                dp_group=self.dp_group,
                dp_group=dp_group,
                tp_group=self.tp_group,
                sp_group=self.sp_group,
                use_ddp=self.use_ddp,
                use_ddp=use_ddp,
                ddp_config=self.ddp_config,
                custom_policy=self.custom_policy,
            )

@@ -224,6 +249,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            reinitialize_optimizer(optimizer, model)

        if self.zero_stage == 0:
            is_zero = False
            if self.precision in ["fp16", "bf16"]:
                optimizer = HybridParallelAMPOptimizer(
                    optimizer,

@@ -236,7 +262,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                )
            else:
                optimizer = HybridParallelNaiveOptimizer(
                    optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                    optimizer,
                    model,
                    use_pipeline=self.enable_pipeline_parallelism,
                    param_info=param_info,
                    max_norm=self.max_norm,
                    pp_process_group=self.pp_group,
                    tp_process_group=self.tp_group,
                )
        else:
            if not (self.dp_size > 1 or self.moe_dp_size > 1):

@@ -244,6 +276,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                    "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                    "If you do not intend to use cpu_offload, please consider set zero_stage=0."
                )
            assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
            optimizer = MoeHybridParallelZeroOptimizer(
                optimizer,
                model,

@@ -262,4 +295,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            # inject update_master_params
            model.update_master_params = MethodType(optimizer.update_master_params, model)

        # Setup optimizers that require global states
        optim = optimizer.optim
        if isinstance(optim, DistributedOptim):
            shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
            padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
            optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)

        return model, optimizer, criterion, dataloader, lr_scheduler
@@ -209,7 +209,7 @@ class ProcessGroupMesh:
        axis: Union[int, List[int]],
        indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
        backend: Optional[str] = None,
        return_ranks_by_group: bool = False
        return_ranks_by_group: bool = False,
    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
        """Create all process groups along the given axis, and return the one which the current process belongs to.

@@ -257,7 +257,11 @@
        return target_group

    def get_group_along_axis(
        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
        self,
        axis: Union[int, List[int]],
        indices_at_axis: Optional[List[int]] = None,
        backend: Optional[str] = None,
        return_ranks_by_group: bool = False,
    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
        """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
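These signature clean-ups matter because the plugin above now asks the mesh for a group spanning two axes at once: with all_to_all sequence parallelism, DDP/ZeRO gradients must be averaged over dp * sp ranks. A minimal sketch of that call (the axis layout and sizes are assumptions for illustration; an initialized torch.distributed process group is required):

    import torch.distributed as dist
    from colossalai.cluster import ProcessGroupMesh

    # Assume 4 ranks arranged as (pp, dp, sp, tp) = (1, 2, 2, 1); the axis indices are illustrative.
    assert dist.is_initialized()
    pg_mesh = ProcessGroupMesh(1, 2, 2, 1)
    DP_AXIS, SP_AXIS = 1, 2
    # One fused group over both axes, as used for the model's dp_group when sp_mode == "all_to_all".
    dp_sp_group = pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])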
@@ -1,26 +1,47 @@
from typing import List, Optional
import inspect
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import (
    MixtralSparseMoeBlock,
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    apply_rotary_pos_emb,
    load_balancing_loss_func,
    repeat_kv,
)
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    def __init__(self, *args, **kwargs):

@@ -97,6 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
        selected_experts_idx = selected_experts.argsort()
        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
        dist.get_rank()
        output_split_sizes = torch.zeros_like(input_split_sizes)
        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

@@ -157,7 +179,7 @@

class MixtralPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Llama models
    This class serves as a micro library for forward function substitution of Mixtral models
    under pipeline setting.
    """
@@ -491,3 +513,335 @@ class MixtralPipelineForwards:
        if output_router_logits:
            out["past_router_logits"] = outputs["past_router_logits"]
        return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"

        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
        bsz, q_len, _ = hidden_states.size()

        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group)
            key_states = all_to_all_comm(key_states, sp_group)
            value_states = all_to_all_comm(value_states, sp_group)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
        )
        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    return forward
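The all_to_all_comm calls above implement the Ulysses exchange: before attention, each rank trades its local sequence shard (all heads) for the full sequence (a slice of heads); after attention the exchange is reversed. A self-contained conceptual sketch of that primitive follows; it is not the library implementation (which also handles autograd and gradient scaling), and it assumes the scattered dimension divides evenly by the group size:

    import torch
    import torch.distributed as dist

    def ulysses_all_to_all(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
        """x is sharded along gather_dim across the group; afterwards it is full
        along gather_dim and sharded along scatter_dim instead (forward only)."""
        world_size = dist.get_world_size(group)
        inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
        outputs = [torch.empty_like(inputs[0]) for _ in inputs]
        dist.all_to_all(outputs, inputs, group=group)  # pairwise chunk exchange
        return torch.cat(outputs, dim=gather_dim)

    # q/k/v: (bsz, seq_len / sp_size, num_heads * head_dim)
    #     -> (bsz, seq_len, num_heads * head_dim / sp_size)
    # so each rank attends over the full sequence for its slice of heads.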

def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = 0

        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )
        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if sp_mode in ["ring", "split_gather"]:
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if sp_mode == "ring" or sp_mode == "split_gather":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    return forward
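In the model forward above, hidden states enter the decoder stack already split along the sequence dimension (split_forward_gather_backward) and are reassembled afterwards (gather_forward_split_backward). Conceptually the first of these behaves like the sketch below; the real operator also applies the grad_scale argument seen above, so this is an approximation for intuition only:

    import torch
    import torch.distributed as dist

    class SplitForwardGatherBackward(torch.autograd.Function):
        """Keep only the local sequence shard in forward; all-gather the gradient in backward."""

        @staticmethod
        def forward(ctx, x, dim, group):
            ctx.dim, ctx.group = dim, group
            rank = dist.get_rank(group)
            world_size = dist.get_world_size(group)
            return x.chunk(world_size, dim=dim)[rank].contiguous()

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.group)
            buffers = [torch.empty_like(grad_output) for _ in range(world_size)]
            dist.all_gather(buffers, grad_output.contiguous(), group=ctx.group)
            return torch.cat(buffers, dim=ctx.dim), None, None

    # usage: local_embeds = SplitForwardGatherBackward.apply(inputs_embeds, 1, sp_group)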
@@ -5,12 +5,17 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.modeling.mixtral import (
    EPMixtralSparseMoeBlock,
    MixtralPipelineForwards,
    get_mixtral_flash_attention_forward,
    get_mixtral_flash_attention_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]

@@ -21,27 +26,72 @@ class MixtralPolicy(Policy):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # non-moe params tensor parallelism
        self.origin_attn_implement = self.model.config._attn_implementation
        # if self.shard_config.enable_tensor_parallelism:
        #     # non-moe params tensor parallelism

            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
        #     # Resize embedding
        #     vocab_size = self.model.config.vocab_size
        #     world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)
        #     if vocab_size % world_size != 0:
        #         new_vocab_size = vocab_size + world_size - vocab_size % world_size
        #         self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}
        from transformers.models.mixtral.modeling_mixtral import (
            MixtralAttention,
            MixtralDecoderLayer,
            MixtralFlashAttention2,
            MixtralModel,
            MixtralSdpaAttention,
        )

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError(
                "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
        ATTN_IMPLEMENTATION = {
            "eager": MixtralAttention,
            "flash_attention_2": MixtralFlashAttention2,
            "sdpa": MixtralSdpaAttention,
        }
        policy = {}
        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        if sp_mode == "all_to_all":
            decoder_attribute_replacement = {
                "num_heads": self.model.config.num_attention_heads // sp_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

            policy[attn_cls] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )
        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
                target_key=attn_cls,
            )
            if self.pipeline_stage_manager is None:
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_mixtral_flash_attention_model_forward(
                            self.shard_config,
                            sp_mode=sp_mode,
                            sp_size=sp_size,
                            sp_group=sp_group,
                        ),
                    },
                    policy=policy,
                    target_key=MixtralModel,
                )

        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
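The all_to_all branch above shrinks the per-rank head counts, since after the Ulysses exchange each rank only holds num_heads / sp_size query heads (and the matching share of key/value heads). Illustrative arithmetic with made-up model sizes (real values come from the model config):

    # Illustrative numbers only.
    num_attention_heads, num_key_value_heads, sp_size = 32, 8, 4
    assert num_attention_heads % sp_size == 0 and num_key_value_heads % sp_size == 0
    local_q_heads = num_attention_heads // sp_size    # 8 query heads per rank
    local_kv_heads = num_key_value_heads // sp_size   # 2 key/value heads per rank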
@@ -127,10 +177,12 @@ class MixtralPolicy(Policy):
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=FusedRMSNorm,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=FusedRMSNorm,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
            ],
            policy=policy,

@@ -141,6 +193,7 @@ class MixtralPolicy(Policy):
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=FusedRMSNorm,
                kwargs={"sp_partial_derived": sp_partial_derived},
            ),
            policy=policy,
            target_key=MixtralModel,

@@ -308,5 +361,5 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in llama for sequence classification model"""
        """No shared params in mixtral for sequence classification model"""
        return []
@@ -48,11 +48,13 @@ loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = MixtralConfig(
    hidden_size=256,
    intermediate_size=256,
    num_attention_heads=64,
    hidden_size=32,
    intermediate_size=32,
    num_attention_heads=8,
    num_hidden_layers=2,
    vocab_size=1000,
    attn_implementation="flash_attention_2",
    torch_dtype="float16",
    output_router_logits=True,
)
@@ -3,6 +3,8 @@ import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

@@ -15,6 +17,7 @@ from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,

@@ -27,13 +30,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    # TODO: SGD failed for full dp
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
    )

    org_model = org_model.to(torch.float16)
    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)
    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

@@ -45,6 +49,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        atol, rtol = 5e-3, 5e-3

    check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
    check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)

    # unwrap model
    mixtral_model = unwrap_model(org_model, "MixtralModel", "model")

@@ -53,6 +58,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]

    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        # booster.plugin.zero_stage in [1, 2]
        booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        rank = dist.get_rank()
        # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
        for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
            try:
                assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
                print(f"{rank=},passed grad: {n1}, {n2}")
            except Exception as e:
                print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
                raise e

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
@@ -84,28 +105,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        grads_to_check.update(row_layer_grads)

    # check grads
    # print(grads_to_check)
    check_all_grad_tensors(grads_to_check)

    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
        try:
            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
            print(f"{rank=},passed param before step: {n1}, {n2}")
        except Exception:
            print(
                f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
            )
    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
        try:
            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
            print(f"{rank=},passed param after step: {n1}, {n2}")
        except Exception as e:
            print(
                f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
            )
            raise e
    # check weights
    if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 2e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        check_weight(
            mixtral_model,
            shard_mixtral_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        try:
            check_weight(
                mixtral_model,
                shard_mixtral_model,
                col_layer_for_check,
                tp_group,
                atol=atol,
                rtol=rtol,
                dim=1,
                verbose=False,
            )
        except Exception as e:
            rank = dist.get_rank()
            print(f"{rank=}, Failed config: {test_config}")
            raise e

    torch.cuda.empty_cache()

@@ -113,33 +155,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 2,
            "ep_size": 2,
            "zero_stage": 1,
            "overlap_communication": False,
            "precision": "fp32",
        },  # [dp(4)] + [moe_dp(4)]
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 2,
            "ep_size": 2,
            "zero_stage": 1,
            "overlap_communication": False,
            "precision": "fp32",
        },  # [dp(2) + pp(2)] + [moe_pp(2)]
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 2,
            "ep_size": 2,
            "zero_stage": 1,
            "overlap_communication": False,
            "precision": "fp32",
        },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
        # {
        #     "tp_size": 1,
        #     "pp_size": 2,

@@ -148,7 +163,38 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(2) + pp(2)] + [ep(4))]
        # },  # [dp(4)] + [moe_dp(4)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 2,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(2) + pp(2)] + [moe_pp(2)]
        # {
        #     "tp_size": 2,
        #     "pp_size": 2,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
        {  # Ulysses + Flash attention
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 4,
            "ep_size": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "zero_stage": 0,
            "overlap_communication": False,
            "precision": "fp16",
            "initial_scale": 1,
            "find_unused_parameters": True,
        },
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,