Making large AI models cheaper, faster and more accessible
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

858 lines
38 KiB

import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
try:
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2Model,
apply_rotary_pos_emb,
repeat_kv,
)
except ImportError:
Qwen2Model = "Qwen2Model"
Qwen2ForCausalLM = "Qwen2ForCausalLM"
Qwen2Attention = "Qwen2Attention"
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy
class Qwen2PipelineForwards:
"""
This class serves as a micro library for forward function substitution of Qwen2 models
under pipeline setting.
"""
@staticmethod
def qwen2_model_forward(
self: Qwen2Model,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
# assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# always return dict for imediate stage
return {"hidden_states": hidden_states}
@staticmethod
def qwen2_for_causal_lm_forward(
self: Qwen2ForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = Qwen2PipelineForwards.qwen2_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def qwen2_for_sequence_classification_forward(
self: Qwen2ForSequenceClassification,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = Qwen2PipelineForwards.qwen2_model_forward(
self.model,
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
batch_size = hidden_states.shape[0]
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None and batch_size != 1:
print(self.config.pad_token_id)
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self: Qwen2Attention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
def forward(
self: Qwen2ForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward