|
|
|
import random
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from transformers.modeling_outputs import (
|
|
|
|
BaseModelOutputWithPast,
|
|
|
|
CausalLMOutputWithPast,
|
|
|
|
QuestionAnsweringModelOutput,
|
|
|
|
SequenceClassifierOutputWithPast,
|
|
|
|
)
|
|
|
|
from transformers.models.opt.modeling_opt import (
|
|
|
|
OPTForCausalLM,
|
|
|
|
OPTForQuestionAnswering,
|
|
|
|
OPTForSequenceClassification,
|
|
|
|
OPTModel,
|
|
|
|
)
|
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
|
|
|
|
|
|
|
class OPTPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of OPT models
    under pipeline setting.

    Every method is a ``staticmethod`` whose first argument plays the role of ``self``
    for the OPT module it replaces. The public forwards share three extra arguments:

    * ``stage_manager`` -- the ``PipelineStageManager`` that tells the current rank
      whether it is the first and/or the last pipeline stage;
    * ``hidden_states`` -- activations received from the previous stage (ignored on the
      first stage, required on every other stage);
    * ``stage_index`` -- ``[start, end)`` indices of the decoder layers run by this stage.

    Stages other than the last return ``{'hidden_states': tensor}``. The last stage
    returns the usual tuple or ``ModelOutput``, depending on ``return_dict``.
    """

    @staticmethod
    def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
        """Build the additive 4-D decoder attention mask.

        Combines a causal mask with the expanded padding mask. Both have shape
        ``[bsz, 1, tgt_seq_len, src_seq_len]``. The causal mask is only built when the
        current input is longer than one token.

        Args:
            attention_mask: 2-D padding mask ``[bsz, src_seq_len]`` (1 = keep, 0 = pad) or ``None``.
            input_shape: ``(bsz, tgt_seq_len)`` of the current input.
            _dtype: dtype of the produced mask.
            device: device of the produced mask.
            past_key_values_length: number of cached positions preceding the current input.

        Returns:
            The combined additive mask, or ``None`` when neither component applies.
        """
        # Imported at call time rather than at module import.
        from transformers.models.opt.modeling_opt import _make_causal_mask

        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                _dtype,
                device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
                                                                  tgt_len=input_shape[-1]).to(device)
            combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        Entries equal to 1 become 0 (attend). Entries equal to 0 become
        ``torch.finfo(dtype).min`` (masked), so the result can be added to attention scores.
        ``tgt_len`` defaults to the source length.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def opt_model_forward(
        self: OPTModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        '''
        This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward

        Runs decoder layers ``stage_index[0]`` up to (excluding) ``stage_index[1]`` of
        ``self.decoder``:

        * First stage: embeds ``input_ids`` (or takes ``inputs_embeds``), applies
          ``project_in`` if present, and adds positional embeddings.
        * Other stages: continue from ``hidden_states``.
        * Last stage: applies ``final_layer_norm`` / ``project_out`` when present, then
          returns a tuple or a ``BaseModelOutputWithPast``.
        * Every other stage returns ``{'hidden_states': ...}``.

        ``past_key_values``, ``output_attentions``, ``output_hidden_states`` and ``use_cache``
        are not supported under pipeline parallelism yet. If set, a warning is logged and
        they are reset to ``None`` / ``False``.

        Raises:
            ValueError: on the first stage, if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given. On later stages, if ``hidden_states`` is
                missing. Also if ``attention_mask`` has the wrong length, or if
                ``head_mask`` does not have one row per decoder layer.
        '''
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        decoder = self.decoder
        if stage_manager.is_first_stage():
            # retrieve input_ids and inputs_embeds
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                input_shape = input_ids.size()
                input_ids = input_ids.view(-1, input_shape[-1])
            elif inputs_embeds is not None:
                input_shape = inputs_embeds.size()[:-1]
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

            batch_size, seq_length = input_shape

            if inputs_embeds is None:
                inputs_embeds = decoder.embed_tokens(input_ids)

            if decoder.project_in is not None:
                inputs_embeds = decoder.project_in(inputs_embeds)
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            _dtype = inputs_embeds.dtype

        else:
            if hidden_states is None:
                raise ValueError("hidden_states shouln't be None for intermediate stages.")
            input_shape = hidden_states.size()[:-1]
            batch_size, seq_length = input_shape[0], input_shape[1]
            device = hidden_states.device
            _dtype = hidden_states.dtype

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values_length + seq_length
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)")

        causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
                                                                                    device, past_key_values_length)

        if stage_manager.is_first_stage():
            pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
            hidden_states = inputs_embeds + pos_embeds

        if decoder.gradient_checkpointing and decoder.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False

        # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if past_key_values:
            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
            past_key_values = None
        if output_attentions:
            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
            output_attentions = False
        if output_hidden_states:
            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
            output_hidden_states = False
        if use_cache:
            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
            use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(decoder.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}.")

        start_idx, end_idx = stage_index[0], stage_index[1]

        # torch.cuda.set_device rejects non-CUDA devices, so only pin the current CUDA
        # device when the activations actually live on one.
        if device.type == 'cuda':
            torch.cuda.set_device(device)

        for idx in range(start_idx, end_idx):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            decoder_layer = decoder.layers[idx]

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            dropout_probability = random.uniform(0, 1)
            if decoder.training and (dropout_probability < decoder.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if decoder.gradient_checkpointing and decoder.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # inputs = (hidden_states, attention_mask, layer_head_mask, past_key_value=None);
                        # appended positionally: output_attentions, use_cache=None
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    causal_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if stage_manager.is_last_stage():
            if decoder.final_layer_norm is not None:
                hidden_states = decoder.final_layer_norm(hidden_states)
            if decoder.project_out is not None:
                hidden_states = decoder.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if stage_manager.is_last_stage():
            if not return_dict:
                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
        else:
            return {'hidden_states': hidden_states}

    @staticmethod
    def opt_for_causal_lm_forward(
        self: OPTForCausalLM,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
        Please refer to original code of transformers for more details.

        On the last stage:

        * projects the decoder output through ``self.lm_head``;
        * when ``labels`` is given, computes the shifted next-token cross-entropy loss;
        * returns a tuple or a ``CausalLMOutputWithPast``.

        Every other stage returns ``{'hidden_states': ...}`` for the next stage.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = OPTPipelineForwards.opt_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
        )
        if stage_manager.is_last_stage():
            logits = self.lm_head(outputs[0]).contiguous()
            loss = None
            if labels is not None:
                # move labels to correct device to enable model parallelism
                labels = labels.to(logits.device)
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
            if not return_dict:
                output = (logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            hidden_states = outputs.get('hidden_states')
            return {'hidden_states': hidden_states}

    @staticmethod
    def opt_for_sequence_classification_forward(
        self: OPTForSequenceClassification,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
        Please refer to original code of transformers for more details.

        On the last stage, every position is scored with ``self.score``. Each row is then
        pooled at its last non-padding token. When no ``pad_token_id`` is configured, or
        only ``inputs_embeds`` is available, the final position is used instead.

        When ``labels`` is given, the loss depends on ``config.problem_type``:

        * ``"regression"`` uses MSE;
        * ``"single_label_classification"`` uses cross-entropy;
        * ``"multi_label_classification"`` uses BCE-with-logits.

        If ``problem_type`` is unset, it is inferred on first use and stored on
        ``self.config``. Other stages return ``{'hidden_states': ...}``.
        """
        logger = logging.get_logger(__name__)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
                                                                    input_ids,
                                                                    past_key_values=past_key_values,
                                                                    attention_mask=attention_mask,
                                                                    head_mask=head_mask,
                                                                    inputs_embeds=inputs_embeds,
                                                                    use_cache=use_cache,
                                                                    output_attentions=output_attentions,
                                                                    output_hidden_states=output_hidden_states,
                                                                    return_dict=return_dict,
                                                                    stage_manager=stage_manager,
                                                                    hidden_states=hidden_states,
                                                                    stage_index=stage_index)

        if stage_manager.is_last_stage():
            hidden_states = transformer_outputs[0]
            logits = self.score(hidden_states)

            batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]

            if self.config.pad_token_id is None:
                sequence_lengths = -1
            else:
                if input_ids is not None:
                    # index of the last non-padding token in each row
                    sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
                else:
                    sequence_lengths = -1
                    logger.warning(
                        f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                        "unexpected if using padding tokens in conjunction with `inputs_embeds.`")

            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

            loss = None
            if labels is not None:
                if self.config.problem_type is None:
                    if self.num_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"

                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_labels == 1:
                        loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                    else:
                        loss = loss_fct(pooled_logits, labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss = loss_fct(pooled_logits, labels)

            if not return_dict:
                output = (pooled_logits,) + transformer_outputs[1:]
                return ((loss,) + output) if loss is not None else output

            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )
        else:
            hidden_states = transformer_outputs.get('hidden_states')
            return {'hidden_states': hidden_states}

    @staticmethod
    def opt_for_question_answering_forward(
        self: OPTForQuestionAnswering,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
        Please refer to original code of transformers for more details.

        On the last stage, ``self.qa_outputs`` produces per-token start/end logits. When
        both ``start_positions`` and ``end_positions`` are given, the loss is the mean of
        the start and end cross-entropy losses. Positions are clamped to
        ``[0, seq_len]``, and the out-of-range value ``seq_len`` is ignored by the loss.
        Other stages return ``{'hidden_states': ...}``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
                                                                    input_ids,
                                                                    past_key_values=past_key_values,
                                                                    attention_mask=attention_mask,
                                                                    head_mask=head_mask,
                                                                    inputs_embeds=inputs_embeds,
                                                                    use_cache=use_cache,
                                                                    output_attentions=output_attentions,
                                                                    output_hidden_states=output_hidden_states,
                                                                    return_dict=return_dict,
                                                                    stage_manager=stage_manager,
                                                                    hidden_states=hidden_states,
                                                                    stage_index=stage_index)
        if stage_manager.is_last_stage():
            hidden_states = transformer_outputs[0]

            logits = self.qa_outputs(hidden_states)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1).contiguous()
            end_logits = end_logits.squeeze(-1).contiguous()

            total_loss = None
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions = start_positions.clamp(0, ignored_index)
                end_positions = end_positions.clamp(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)
                total_loss = (start_loss + end_loss) / 2

            if not return_dict:
                output = (start_logits, end_logits) + transformer_outputs[2:]
                return ((total_loss,) + output) if total_loss is not None else output

            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )
        else:
            hidden_states = transformer_outputs.get('hidden_states')
            return {'hidden_states': hidden_states}
|
|
|
|
|
|
|
|
|
|
|
|
def get_opt_flash_attention_forward():
    """Return an ``OPTAttention.forward`` replacement backed by ColossalAI's ``ColoAttention``.

    The returned ``forward`` is meant to be bound onto an ``OPTAttention`` module. It
    projects q/k/v, then runs them through a ``ColoAttention`` built on every call.
    The mask type is ``AttnMaskType.causal``, or ``AttnMaskType.paddedcausal`` when an
    attention mask is supplied.
    """
    from transformers.models.opt.modeling_opt import OPTAttention

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    def forward(
        self: OPTAttention,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None, past_key_value)``. Attention weights are never
        materialised, so the second element is ``None`` whatever ``output_attentions`` is.

        Query/key/value are shaped ``(bsz, seq_len, num_heads, head_dim)``. When
        ``self.is_decoder`` is set, the returned ``past_key_value`` holds key/value
        states in that same layout; otherwise it is whatever was passed in.

        Raises:
            AssertionError: if the sequence length is not a multiple of 4.
            ValueError: if ``layer_head_mask`` or ``attention_mask`` has an unexpected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()
        assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
        # get query proj
        query_states = self.q_proj(hidden_states).view(*attention_input_shape)
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            # NOTE(review): this branch transposes dims 1/2 of the cached states, whereas the
            # self-attention cache below is stored untransposed as (bsz, seq, heads, head_dim);
            # confirm the expected cache layout before relying on cross-attention caching.
            key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
            value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
        elif is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states).view(*attention_input_shape)
            value_states = self.v_proj(key_value_states).view(*attention_input_shape)
        elif past_key_value is not None:
            # reuse k, v, self_attention; concatenate along the sequence dimension (dim 1)
            key_states = self.k_proj(hidden_states).view(*attention_input_shape)
            value_states = self.v_proj(hidden_states).view(*attention_input_shape)
            key_states = torch.cat([past_key_value[0], key_states], dim=1)
            value_states = torch.cat([past_key_value[1], value_states], dim=1)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states).view(*attention_input_shape)
            value_states = self.v_proj(hidden_states).view(*attention_input_shape)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        src_len = key_states.size(1)
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                                 f" {layer_head_mask.size()}")
        # NOTE(review): layer_head_mask is only shape-checked; it is not applied to the
        # attention computation below.

        flash_attention_mask = None
        attn_mask_type = AttnMaskType.causal
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
            # Take the last query row of the additive mask, which leaves only the padding
            # entries non-zero. Invert it so True marks positions that may be attended.
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
            attn_mask_type = AttnMaskType.paddedcausal

        # The ColoAttention below is a freshly built module, so it is always in training
        # mode. Pass a zero dropout probability explicitly when this layer is evaluating;
        # otherwise dropout would also be applied at inference time.
        attention = ColoAttention(embed_dim=self.embed_dim,
                                  num_heads=self.num_heads,
                                  dropout=self.dropout if self.training else 0.0,
                                  scale=self.scaling)
        attn_output = attention(query_states,
                                key_states,
                                value_states,
                                attn_mask=flash_attention_mask,
                                attn_mask_type=attn_mask_type)

        attn_output = self.out_proj(attn_output)
        return attn_output, None, past_key_value

    return forward
|
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_opt_decoder_layer_forward():
    """Return an ``OPTDecoderLayer.forward`` replacement that uses a fused dropout + residual add.

    The returned function must be bound to an ``OPTDecoderLayer`` that also provides
    ``self.dropout_add(x, residual, p, training)``. Apart from using that helper for
    both residual connections, it follows the stock OPT decoder layer: pre- or
    post-LayerNorm depending on ``self.do_layer_norm_before``.
    """
    from transformers.models.opt.modeling_opt import OPTDecoderLayer

    def forward(
        self: OPTDecoderLayer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one OPT decoder layer: self-attention block followed by a feed-forward block.

        Args:
            hidden_states: layer input, documented upstream as ``(batch, seq_len, embed_dim)``.
            attention_mask: optional additive mask passed straight to ``self.self_attn``
                (documented upstream as ``(batch, 1, tgt_len, src_len)``).
            layer_head_mask: optional per-head mask forwarded to ``self.self_attn``.
            past_key_value: optional cached key/value states forwarded to ``self.self_attn``.
            output_attentions: when truthy, the attention weights from ``self.self_attn``
                are appended to the output.
            use_cache: when truthy, the key/value states returned by ``self.self_attn``
                are appended to the output.

        Returns:
            A tuple ``(hidden_states[, attn_weights][, present_key_value])``.
        """
        pre_norm = self.do_layer_norm_before

        # ---- self-attention sub-block ----
        attn_skip = hidden_states
        attn_input = self.self_attn_layer_norm(hidden_states) if pre_norm else hidden_states
        attn_out, attn_weights, present_kv = self.self_attn(
            hidden_states=attn_input,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout_add(attn_out, attn_skip, self.dropout, self.training)
        if not pre_norm:
            # post-LayerNorm variant normalises after the residual add
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- feed-forward sub-block, computed on a flattened (tokens, embed_dim) view ----
        shape_before_ffn = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden_states.size(-1))
        ffn_skip = flat
        ffn = self.final_layer_norm(flat) if pre_norm else flat
        ffn = self.fc2(self.activation_fn(self.fc1(ffn)))
        hidden_states = self.dropout_add(ffn, ffn_skip, self.dropout, self.training).view(shape_before_ffn)
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_kv)
        return (hidden_states, *extras)

    return forward
|