diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index bf450534f..0ad3889ae 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,4 +1,6 @@
 import warnings
+from collections import defaultdict
+from copy import deepcopy
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
 
@@ -22,6 +24,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -114,21 +118,25 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.ddp_config["find_unused_parameters"] = True
 
         world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
             raise ValueError(
-                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size} * sp_size={self.sp_size}"
             )
 
-        self._init_moe_param_comm()
+        # self._init_moe_param_comm()
 
         self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
 
         # set ep_group after super init
         # TODO do it in a better way
+        self.moe_dp_group = self.pp_group
+        self.ep_group = self.pp_group
+        self.moe_tp_group = self.pp_group
+
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
@@ -205,15 +213,32 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
+
+        # TODO: Support Galore + ZeRO
+        zero_stage = self.zero_stage
+        zero_config = deepcopy(self.zero_config)
+        # Replace with distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if not isinstance(model, ModelWrapper):
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            else:
+                dp_group = self.dp_group
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=self.dp_group,
+                dp_group=dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=self.use_ddp,
+                use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
             )
@@ -224,6 +249,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 reinitialize_optimizer(optimizer, model)
 
         if self.zero_stage == 0:
+            is_zero = False
            if self.precision in ["fp16", "bf16"]:
                 optimizer = HybridParallelAMPOptimizer(
                     optimizer,
@@ -236,7 +262,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 )
             else:
                 optimizer = HybridParallelNaiveOptimizer(
-                    optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+                    optimizer,
+                    model,
+                    use_pipeline=self.enable_pipeline_parallelism,
+                    param_info=param_info,
+                    max_norm=self.max_norm,
+                    pp_process_group=self.pp_group,
+                    tp_process_group=self.tp_group,
                 )
         else:
             if not (self.dp_size > 1 or self.moe_dp_size > 1):
@@ -244,6 +276,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                    "If you do not intend to use cpu_offload, please consider set zero_stage=0."
                 )
+            assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+            is_zero = True
             optimizer = MoeHybridParallelZeroOptimizer(
                 optimizer,
                 model,
@@ -262,4 +295,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
 
+        # Setup optimizers that require global states
+        optim = optimizer.optim
+        if isinstance(optim, DistributedOptim):
+            shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
+            padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
+            optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 66b77f7a2..a9d341efa 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -209,7 +209,7 @@ class ProcessGroupMesh:
         axis: Union[int, List[int]],
         indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
         backend: Optional[str] = None,
-        return_ranks_by_group: bool = False
+        return_ranks_by_group: bool = False,
     ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Create all process groups along the given axis, and return the one which the current process belongs to.
 
@@ -257,7 +257,11 @@ class ProcessGroupMesh:
         return target_group
 
     def get_group_along_axis(
-        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
+        self,
+        axis: Union[int, List[int]],
+        indices_at_axis: Optional[List[int]] = None,
+        backend: Optional[str] = None,
+        return_ranks_by_group: bool = False,
     ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Get the process group along the given axis which the current process belongs to.
 
         If the process group doesn't exist, it will be created.
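
The plugin change above depends on `ProcessGroupMesh.create_group_along_axis` accepting a list of axes, so that gradient synchronization can run over a combined dp x sp group when `sequence_parallelism_mode == "all_to_all"`. A minimal sketch of that call, assuming a 4-rank `torchrun` launch and an illustrative (dp, sp) axis layout:

```python
# Hedged sketch: build a process group spanning both the dp and sp axes,
# mirroring `self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])`
# in the plugin. Mesh shape and axis indices here are illustrative assumptions.
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh

# e.g. `torchrun --nproc_per_node=4 this_script.py`; use "gloo" for a CPU-only dry run.
dist.init_process_group(backend="nccl")

DP_AXIS, SP_AXIS = 0, 1
mesh = ProcessGroupMesh(2, 2)  # dp_size=2, sp_size=2

# Passing a list of axes yields a group containing every rank that shares the
# remaining (non-listed) mesh coordinates; with no other axes, that is all 4 ranks.
dp_sp_group = mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
print(f"rank {dist.get_rank()}: dp x sp group size = {dist.get_world_size(dp_sp_group)}")
```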
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index f8745c1d0..2b50f013d 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -1,26 +1,47 @@
-from typing import List, Optional
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralSparseMoeBlock,
     MoeCausalLMOutputWithPast,
+    MoeModelOutputWithPast,
+    apply_rotary_pos_emb,
     load_balancing_loss_func,
+    repeat_kv,
 )
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
 from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+
 
 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
@@ -157,7 +179,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
 class MixtralPipelineForwards:
     """
-    This class serves as a micro library for forward function substitution of Llama models
+    This class serves as a micro library for forward function substitution of Mixtral models
     under pipeline setting.
""" @@ -491,3 +513,335 @@ class MixtralPipelineForwards: if output_router_logits: out["past_router_logits"] = outputs["past_router_logits"] return out + + +def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + +def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 
1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + return forward diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 8905b5696..10f54e1a4 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -5,12 +5,17 @@ from typing import Callable, Dict, List, Union import torch.nn as nn from torch import Tensor from torch.nn import Module -from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.linear import Linear1D_Row -from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards +from colossalai.shardformer.modeling.mixtral import ( + EPMixtralSparseMoeBlock, + MixtralPipelineForwards, + get_mixtral_flash_attention_forward, + get_mixtral_flash_attention_model_forward, +) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] @@ -21,27 +26,72 @@ class MixtralPolicy(Policy): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # non-moe params tensor parallelism + self.origin_attn_implement = self.model.config._attn_implementation + # if self.shard_config.enable_tensor_parallelism: + # # non-moe params tensor parallelism - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size + # # Resize embedding + # vocab_size = self.model.config.vocab_size + # world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + # if vocab_size % world_size != 0: + # new_vocab_size = vocab_size + world_size - vocab_size % world_size + # self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - policy = {} + from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralFlashAttention2, + MixtralModel, + MixtralSdpaAttention, + ) - if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False - raise NotImplementedError( - "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." 
+ ATTN_IMPLEMENTATION = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, ) + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + self.append_or_create_method_replacement( + description={ + "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_mixtral_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=MixtralModel, + ) embedding_cls = None if self.shard_config.enable_tensor_parallelism: @@ -127,10 +177,12 @@ class MixtralPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -141,6 +193,7 @@ class MixtralPolicy(Policy): description=SubModuleReplacementDescription( suffix="norm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=MixtralModel, @@ -308,5 +361,5 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama for sequence classification model""" + """No shared params in mixtral for sequence classification model""" return [] diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py index 7fa4ff335..40e5a7b02 100644 --- a/tests/kit/model_zoo/transformers/mixtral.py +++ b/tests/kit/model_zoo/transformers/mixtral.py @@ -48,11 +48,13 @@ loss_fn = lambda x: x.loss loss_fn_for_seq_classification = lambda output: output.logits.mean() config = MixtralConfig( - hidden_size=256, - intermediate_size=256, - num_attention_heads=64, + hidden_size=32, + intermediate_size=32, + num_attention_heads=8, num_hidden_layers=2, vocab_size=1000, + attn_implementation="flash_attention_2", + torch_dtype="float16", output_router_logits=True, ) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index f268d1686..2e2b675a4 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -3,6 +3,8 @@ import os import pytest import torch +import torch.distributed as dist +from torch.testing import assert_close import colossalai from 
colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin @@ -15,6 +17,7 @@ from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_all_grad_tensors, check_loss, + check_output_hidden_state, check_weight, get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, @@ -27,13 +30,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # TODO: SGD failed for full dp org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam + model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD ) + org_model = org_model.to(torch.float16) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) - + print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group @@ -45,6 +49,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol) # unwrap model mixtral_model = unwrap_model(org_model, "MixtralModel", "model") @@ -53,6 +58,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] col_layer_for_check = ["layers[0].self_attn.o_proj"] + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + # booster.plugin.zero_stage in [1, 2] + booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + rank = dist.get_rank() + # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()): + try: + assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False) + print(f"{rank=},passed grad: {n1}, {n2}") + except Exception as e: + print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}") + raise e + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: @@ -84,28 +105,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, grads_to_check.update(row_layer_grads) # check grads + # print(grads_to_check) check_all_grad_tensors(grads_to_check) - + for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()): + try: + assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False) + print(f"{rank=},passed param before step: {n1}, {n2}") + except Exception: + print( + f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}" + ) # optimizer executes step org_optimizer.step() sharded_optimizer.step() - + for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()): + try: + assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False) + print(f"{rank=},passed param after step: {n1}, {n2}") + except Exception as e: + print( + f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}" + ) + raise e # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config["precision"] == "fp32": atol, rtol = 2e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight( - mixtral_model, - shard_mixtral_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) + try: + check_weight( + mixtral_model, + shard_mixtral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + except Exception as e: + rank = dist.get_rank() + print(f"{rank=}, Failed config: {test_config}") + raise e torch.cuda.empty_cache() @@ -113,33 +155,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "ep_size": 2, - "zero_stage": 1, - "overlap_communication": False, - "precision": "fp32", - }, # [dp(4)] + [moe_dp(4)] - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "ep_size": 2, - "zero_stage": 1, - "overlap_communication": False, - "precision": "fp32", - }, # [dp(2) + pp(2)] + [moe_pp(2)] - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "ep_size": 2, - "zero_stage": 1, - "overlap_communication": False, - "precision": "fp32", - }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass # { # "tp_size": 1, # "pp_size": 2, @@ -148,7 +163,38 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # "zero_stage": 1, # "overlap_communication": False, # "precision": "fp32", - # }, # [dp(2) + pp(2)] + [ep(4))] + # }, # [dp(4)] + [moe_dp(4)] + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "ep_size": 2, + # "zero_stage": 1, + # "overlap_communication": False, + # "precision": "fp32", + # }, # [dp(2) + pp(2)] + [moe_pp(2)] + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "ep_size": 2, + # "zero_stage": 1, + # "overlap_communication": False, + # "precision": "fp32", + # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "ep_size": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "zero_stage": 0, + "overlap_communication": False, + "precision": "fp16", + "initial_scale": 1, + "find_unused_parameters": True, + }, # { # "tp_size": 1, # 
"pp_size": 1,