import inspect
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import (
    MixtralSparseMoeBlock,
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    apply_rotary_pos_emb,
    load_balancing_loss_func,
    repeat_kv,
)
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    def __init__(self, *args, **kwargs):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
        assert tp_group is not None
        assert moe_dp_group is not None
        assert ep_group is not None

        # setup ep group
        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.ep_group = ep_group

        if self.num_experts % self.ep_size != 0:
            raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")

        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = moe_dp_group.size()

        # setup global tp group
        self.tp_group = tp_group
        if self.tp_group.size() > 1:
            for expert in held_experts:
                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module: MixtralSparseMoeBlock,
        tp_group: ProcessGroup,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EPMixtralSparseMoeBlock":
        # TODO: better init
        LazyInitContext.materialize(module)
        module.__class__ = EPMixtralSparseMoeBlock
        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
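        # Expert-parallel dispatch: tokens are routed to their top-k experts, sorted by
        # expert id, and exchanged across the EP group with an uneven all-to-all so that
        # each rank only runs the experts it holds locally. A reverse all-to-all plus a
        # weighted sum over the top-k expert outputs restores the original token order.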
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        selected_experts = selected_experts.t().reshape(-1)
        selected_experts_idx = selected_experts.argsort()
        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
        output_split_sizes = torch.zeros_like(input_split_sizes)

        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

        with torch.no_grad():
            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
            for i in range(1, self.ep_size):
                activate_experts += output_split_sizes[
                    i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep
                ]
            activate_experts = (activate_experts > 0).float()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)

        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()

        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        # compute expert output
        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
        if output_states.size(0) > 0:
            if self.num_experts_per_ep == 1:
                # no need to split
                expert = self.experts[self.expert_start_idx]
                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                output_states = expert.w2(output_states)
                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
            else:
                output_states_splits = output_states.split(output_split_sizes.tolist())
                output_states_list = []
                for i, split_states in enumerate(output_states_splits):
                    if split_states.size(0) == 0:
                        continue
                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                    split_states = DPGradScalerIn.apply(
                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                    )
                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                    split_states = expert.w2(split_states)
                    split_states = DPGradScalerOut.apply(
                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                    )
                    output_states_list.append(split_states)
                output_states = torch.cat(output_states_list)
        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)

        recover_experts_idx = torch.empty_like(selected_experts_idx)
        recover_experts_idx[selected_experts_idx] = torch.arange(
            selected_experts_idx.size(0), device=selected_experts_idx.device
        )
        dispatch_states = dispatch_states[recover_experts_idx]
        k_hidden_states = dispatch_states.chunk(self.top_k)
        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
        for i in range(1, self.top_k):
            output_states += k_hidden_states[i] * routing_weights[:, i, None]
        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
        return output_states, router_logits


class MixtralPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Mixtral models under
    pipeline setting.
    """
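    # These forwards mirror the upstream HuggingFace implementations but take extra
    # pipeline arguments: each stage runs only its slice of decoder layers (`stage_index`),
    # receives `hidden_states` from the previous stage, and threads `past_router_logits`
    # through so the auxiliary load-balancing loss can be computed on the last stage.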
    @staticmethod
    def mixtral_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device

        seq_length_with_past = seq_length
        past_key_values_length = 0

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
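        # Per-layer attentions, all hidden states and the kv-cache are not propagated
        # across pipeline stages, so the corresponding options are force-disabled below.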
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        if is_flash_attn_2_available():
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    output_router_logits,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None

        if output_router_logits and past_router_logits is not None:
            all_router_logits = past_router_logits + all_router_logits
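        # Only the last stage returns the full (tuple or MoeModelOutputWithPast) result;
        # intermediate stages pass hidden_states (and the accumulated router logits) on
        # to the next stage.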
        if stage_manager.is_last_stage():
            if not return_dict:
                return tuple(
                    v
                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                    if v is not None
                )
            return MoeModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
                router_logits=all_router_logits,
            )
        else:
            if output_router_logits:
                return {
                    "hidden_states": hidden_states,
                    "past_router_logits": all_router_logits,
                }
            else:
                return {
                    "hidden_states": hidden_states,
                }
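    # Illustrative note (an assumption about typical usage, not part of this module): a
    # shardformer policy is expected to install these staticmethods in place of the
    # original `forward`s, binding the pipeline context ahead of time, e.g. via
    #
    #   functools.partial(
    #       MixtralPipelineForwards.mixtral_for_causal_lm_forward,
    #       stage_manager=stage_manager,
    #       stage_index=stage_index,
    #       shard_config=shard_config,
    #   )
    #
    # so that each pipeline stage only executes its own slice of decoder layers.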
    @staticmethod
    def mixtral_for_causal_lm_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = MixtralPipelineForwards.mixtral_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            past_router_logits=past_router_logits,
        )
        past_key_values = None

        if stage_manager.is_last_stage():
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits.float()

            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

            aux_loss = None
            if output_router_logits:
                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                if labels is not None:
                    loss += self.router_aux_loss_coef * aux_loss

            if not return_dict:
                output = (logits,) + outputs[1:]
                if output_router_logits:
                    output = (aux_loss,) + output
                return (loss,) + output if loss is not None else output

            return MoeCausalLMOutputWithPast(
                loss=loss,
                aux_loss=aux_loss,
                logits=logits,
                past_key_values=None,
                hidden_states=outputs[0],
                attentions=None,
                router_logits=outputs[-1],
            )
        else:
            out = {}
            hidden_states = outputs.get("hidden_states")
            out["hidden_states"] = hidden_states
            if output_router_logits:
                out["past_router_logits"] = outputs["past_router_logits"]
            return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"

        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure to use `attention_mask` instead."
            )
            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        bsz, q_len, _ = hidden_states.size()
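        # sp: under sequence parallelism the incoming sequence dimension is sharded across
        # the sp group. For "split_gather"/"ring" the logical sequence length is recovered
        # as q_len * sp_size; for "all_to_all" the q/k/v projections are redistributed so
        # that each rank holds the full sequence for a subset of attention heads.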
        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group)
            key_states = all_to_all_comm(key_states, sp_group)
            value_states = all_to_all_comm(value_states, sp_group)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention; for a more "
                "memory-efficient implementation, make sure to upgrade the flash-attn library."
            )
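        # When the kv-cache has grown beyond the configured sliding window, only the most
        # recent `sliding_window - 1` cached positions (and the matching slice of the
        # attention mask) are kept before the new key/value states are appended.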
        if past_key_value is not None:
            # Activate slicing cache only if the config has a `sliding_window` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, the layer norms are usually cast to float32 for training stability,
        # so the input hidden states may be silently cast to float32 as well. We cast
        # them back to the expected compute dtype to make sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast to float32; this might be related to"
                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
                f" {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    return forward


def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = 0

        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
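        # When caching is enabled, a legacy tuple-style cache is converted to a
        # `DynamicCache` for the forward pass and converted back to the legacy format
        # before returning.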
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if sp_mode in ["ring", "split_gather"]:
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if sp_mode == "ring" or sp_mode == "split_gather":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    return forward
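# Illustrative usage sketch (an assumption about how this module is consumed, not code
# that runs here): the expert-parallel block replaces the native MoE block in place,
# with the tensor-, MoE-data- and expert-parallel process groups supplied by the caller.
#
#   for module in model.modules():
#       if isinstance(module, MixtralSparseMoeBlock):
#           EPMixtralSparseMoeBlock.from_native_module(
#               module, tp_group=tp_group, moe_dp_group=moe_dp_group, ep_group=ep_group
#           )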