ColossalAI/colossalai/shardformer/modeling/opt.py

1006 lines
44 KiB
Python

import random
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.opt.modeling_opt import (
OPTForCausalLM,
OPTForQuestionAnswering,
OPTForSequenceClassification,
OPTModel,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: OPTModel,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values_length: int,
attention_mask: Optional[torch.FloatTensor],
):
batch_size, seq_length = hidden_states.shape[:2]
mask_seq_length = past_key_values_length + seq_length
if shard_config.enable_flash_attention:
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_length, mask_seq_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
return attention_mask
class OPTPipelineForwards:
"""
This class serves as a micro library for forward function substitution of OPT models
under pipeline setting.
"""
@staticmethod
def opt_model_forward(
self: OPTModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
"""
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
decoder = self.decoder
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
batch_size, seq_length = input_shape
if inputs_embeds is None:
inputs_embeds = decoder.embed_tokens(input_ids)
if decoder.project_in is not None:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
inputs_embeds.dtype
hidden_states = inputs_embeds
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for intermediate stages.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
hidden_states.dtype
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if self.decoder._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
else:
# 4d mask is passed through the layers
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
)
if stage_manager.is_first_stage():
causal_attention_mask = _get_attention_mask(
self,
shard_config,
inputs_embeds,
past_key_values_length,
attention_mask,
)
pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
hidden_states = inputs_embeds + pos_embeds
else:
causal_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values_length,
attention_mask,
)
if decoder.gradient_checkpointing and decoder.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(decoder.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
start_idx, end_idx = stage_index[0], stage_index[1]
torch.cuda.set_device(device)
for idx in range(start_idx, end_idx):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
decoder_layer = decoder.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if decoder.training and (dropout_probability < decoder.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
if decoder.final_layer_norm is not None:
hidden_states = decoder.final_layer_norm(hidden_states)
if decoder.project_out is not None:
hidden_states = decoder.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
else:
return {"hidden_states": hidden_states}
@staticmethod
def opt_for_causal_lm_forward(
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
Please refer to original code of transformers for more details.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = OPTPipelineForwards.opt_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def opt_for_sequence_classification_forward(
self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
Please refer to original code of transformers for more details.
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = OPTPipelineForwards.opt_model_forward(
self.model,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def opt_for_question_answering_forward(
self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = OPTPipelineForwards.opt_model_forward(
self.model,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + transformer_outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
def get_opt_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTAttention
def forward(
self: OPTAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[dict] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
dropout_p = self.dropout if self.training else 0.0
attn_output = ColoAttention.attention(
query_states,
key_states,
value_states,
**attention_mask,
dropout_p=dropout_p,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTDecoder
def forward(
self: OPTDecoder,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _get_attention_mask(
self, shard_config, inputs_embeds, past_key_values_length, attention_mask
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_jit_fused_opt_decoder_layer_forward():
from transformers.models.opt.modeling_opt import OPTDecoderLayer
def forward(
self: OPTDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
def forward(
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, OPTForCausalLM
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward