[Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
colossalchat
Haze188 2024-07-18 11:37:56 +08:00 committed by Hongxin Liu
parent 3e2b6132b7
commit 404b16faf3
6 changed files with 571 additions and 72 deletions

View File

@ -1,4 +1,6 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@ -22,6 +24,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@ -114,21 +118,25 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.ddp_config["find_unused_parameters"] = True
world_size = dist.get_world_size()
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
self._init_moe_param_comm()
# self._init_moe_param_comm()
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
# set ep_group after super init
# TODO do it in a better way
self.moe_dp_group = self.pp_group
self.ep_group = self.pp_group
self.moe_tp_group = self.pp_group
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
@ -205,15 +213,32 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
# TODO: Support Galore + ZeRO
self.zero_stage
deepcopy(self.zero_config)
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=self.use_ddp,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@ -224,6 +249,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@ -236,7 +262,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
)
else:
if not (self.dp_size > 1 or self.moe_dp_size > 1):
@ -244,6 +276,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
@ -262,4 +295,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
# Setup optimizers that require global states
optim = optimizer.optim
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
return model, optimizer, criterion, dataloader, lr_scheduler

View File

@ -209,7 +209,7 @@ class ProcessGroupMesh:
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
@ -257,7 +257,11 @@ class ProcessGroupMesh:
return target_group
def get_group_along_axis(
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
self,
axis: Union[int, List[int]],
indices_at_axis: Optional[List[int]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.

View File

@ -1,26 +1,47 @@
from typing import List, Optional
import inspect
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
apply_rotary_pos_emb,
load_balancing_loss_func,
repeat_kv,
)
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
if is_flash_attn_2_available():
from flash_attn import flash_attn_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
@ -97,6 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
dist.get_rank()
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
@ -157,7 +179,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
class MixtralPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
This class serves as a micro library for forward function substitution of Mixtral models
under pipeline setting.
"""
@ -491,3 +513,335 @@ class MixtralPipelineForwards:
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
" make sure to upgrade flash-attn library."
)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_sliding_windows=use_sliding_windows,
)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return forward
def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = 0
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
return forward

View File

@ -5,12 +5,17 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.modeling.mixtral import (
EPMixtralSparseMoeBlock,
MixtralPipelineForwards,
get_mixtral_flash_attention_forward,
get_mixtral_flash_attention_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
@ -21,27 +26,72 @@ class MixtralPolicy(Policy):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# non-moe params tensor parallelism
self.origin_attn_implement = self.model.config._attn_implementation
# if self.shard_config.enable_tensor_parallelism:
# # non-moe params tensor parallelism
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
# # Resize embedding
# vocab_size = self.model.config.vocab_size
# world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
# if vocab_size % world_size != 0:
# new_vocab_size = vocab_size + world_size - vocab_size % world_size
# self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
MixtralFlashAttention2,
MixtralModel,
MixtralSdpaAttention,
)
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
ATTN_IMPLEMENTATION = {
"eager": MixtralAttention,
"flash_attention_2": MixtralFlashAttention2,
"sdpa": MixtralSdpaAttention,
}
policy = {}
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
if sp_mode == "all_to_all":
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={
"forward": get_mixtral_flash_attention_model_forward(
self.shard_config,
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=MixtralModel,
)
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
@ -127,10 +177,12 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@ -141,6 +193,7 @@ class MixtralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=MixtralModel,
@ -308,5 +361,5 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
"""No shared params in mixtral for sequence classification model"""
return []

View File

@ -48,11 +48,13 @@ loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = MixtralConfig(
hidden_size=256,
intermediate_size=256,
num_attention_heads=64,
hidden_size=32,
intermediate_size=32,
num_attention_heads=8,
num_hidden_layers=2,
vocab_size=1000,
attn_implementation="flash_attention_2",
torch_dtype="float16",
output_router_logits=True,
)

View File

@ -3,6 +3,8 @@ import os
import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
@ -15,6 +17,7 @@ from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
@ -27,13 +30,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
)
org_model = org_model.to(torch.float16)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
@ -45,6 +49,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
# unwrap model
mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
@ -53,6 +58,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Check the grad when using ZeRO-1 and ZeRO-2
if (
# booster.plugin.zero_stage in [1, 2]
booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
rank = dist.get_rank()
# for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
try:
assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
print(f"{rank=},passed grad: {n1}, {n2}")
except Exception as e:
print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
raise e
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
@ -84,28 +105,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check.update(row_layer_grads)
# check grads
# print(grads_to_check)
check_all_grad_tensors(grads_to_check)
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
try:
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
print(f"{rank=},passed param before step: {n1}, {n2}")
except Exception:
print(
f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
try:
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
print(f"{rank=},passed param after step: {n1}, {n2}")
except Exception as e:
print(
f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
)
raise e
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
mixtral_model,
shard_mixtral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
try:
check_weight(
mixtral_model,
shard_mixtral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
except Exception as e:
rank = dist.get_rank()
print(f"{rank=}, Failed config: {test_config}")
raise e
torch.cuda.empty_cache()
@ -113,33 +155,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
}, # [dp(4)] + [moe_dp(4)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [moe_pp(2)]
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
}, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
# {
# "tp_size": 1,
# "pp_size": 2,
@ -148,7 +163,38 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [ep(4))]
# }, # [dp(4)] + [moe_dp(4)]
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [moe_pp(2)]
# {
# "tp_size": 2,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 1,
"sp_size": 4,
"ep_size": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"zero_stage": 0,
"overlap_communication": False,
"precision": "fp16",
"initial_scale": 1,
"find_unused_parameters": True,
},
# {
# "tp_size": 1,
# "pp_size": 1,