|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
|
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
from colossalai.shardformer.layer import ColoAttention
|
|
|
|
|
|
|
|
|
|
|
|
def _encoder_forward(
    encoder: ViTEncoder,
    start_idx: int,
    end_idx: int,
    hidden_states: torch.Tensor,
    head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    return_dict: bool = True,
    stage_manager: PipelineStageManager = None,
) -> Union[tuple, BaseModelOutput]:
    """Run the slice of encoder layers ``encoder.layer[start_idx:end_idx]`` owned by this stage.

    Args:
        encoder: The ViT encoder whose ``layer`` list is indexed.
        start_idx: Index of the first layer to run (inclusive).
        end_idx: Index one past the last layer to run (exclusive).
        hidden_states: Input activations for the first layer of this stage.
        head_mask: Optional per-layer head mask; ``head_mask[i]`` is passed to layer ``i``.
        output_attentions: Forwarded to every layer call.
        output_hidden_states: Accepted for signature compatibility only; intermediate
            hidden states are not collected.
        return_dict: On the last stage, whether to return a ``BaseModelOutput`` rather
            than a plain tuple.
        stage_manager: Pipeline stage manager; used to detect the last stage.

    Returns:
        On a non-last stage, the raw ``hidden_states`` tensor to hand to the next stage.
        On the last stage, ``(hidden_states,)`` when ``return_dict`` is False, otherwise a
        ``BaseModelOutput`` whose ``hidden_states`` and ``attentions`` are ``None``.
    """
    for i in range(start_idx, end_idx):
        layer_module = encoder.layer[i]
        layer_head_mask = head_mask[i] if head_mask is not None else None

        if encoder.gradient_checkpointing and encoder.training:
            layer_outputs = encoder._gradient_checkpointing_func(
                layer_module.__call__,
                hidden_states,
                layer_head_mask,
                output_attentions,
            )
        else:
            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

        # Only element 0 of the layer output (the new hidden states) is carried forward.
        hidden_states = layer_outputs[0]

    if not stage_manager.is_last_stage():
        return hidden_states

    if not return_dict:
        # Wrap the tensor in a 1-tuple. ``tuple(hidden_states)`` would iterate the tensor
        # along its first dimension and split the batch into per-sample slices.
        return (hidden_states,)
    return BaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=None,
        attentions=None,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel replacement for ``ViTModel.forward``.

    Args:
        stage_manager: Reports whether the current rank is the first and/or last stage.
        stage_index: ``stage_index[0]`` / ``stage_index[1]`` are the start (inclusive) and
            end (exclusive) indices of the encoder layers executed on this stage.

    Returns:
        ``pp_forward``, a function taking ``self`` (the ViT model) plus the usual forward
        arguments and an extra ``hidden_states`` input used by non-first stages.
    """
    from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling

    # Fetched once per factory call instead of on every forward pass.
    logger = logging.get_logger(__name__)

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        hidden_states (`torch.FloatTensor`, *optional*):
            Activations received from the previous pipeline stage. Required on every stage
            except the first.

        Returns:
            On a non-last stage, ``{"hidden_states": tensor}`` for the next stage. On the last
            stage, a tuple or a ``BaseModelOutputWithPooling`` depending on ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Preprocess passed in arguments: attention maps and per-layer hidden states
        # are not supported by this pipeline forward, so the flags are forced off.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if stage_manager.is_first_stage():
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
            expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
            if pixel_values.dtype != expected_dtype:
                pixel_values = pixel_values.to(expected_dtype)

            hidden_states = self.embeddings(
                pixel_values,
                bool_masked_pos=bool_masked_pos,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )
        else:
            assert (
                hidden_states is not None
            ), f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        # Always request the structured output from the encoder and read fields by name,
        # so the tuple form returned to *our* caller is assembled here explicitly.
        encoder_outputs = _encoder_forward(
            encoder=self.encoder,
            start_idx=stage_index[0],
            end_idx=stage_index[1],
            hidden_states=hidden_states,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            stage_manager=stage_manager,
        )
        if not stage_manager.is_last_stage():
            # On non-last stages the encoder returns the bare activation tensor.
            return {"hidden_states": encoder_outputs}

        sequence_output = self.layernorm(encoder_outputs.last_hidden_state)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            # Slicing a ModelOutput slices its tuple of non-None fields.
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    return pp_forward
|
|
|
|
|
|
|
|
|
|
|
|
def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel replacement for ``ViTForImageClassification.forward``.

    Args:
        stage_manager: Reports whether the current rank is the first and/or last stage.
        stage_index: Layer range for this stage; not read directly here (the wrapped
            ``self.vit`` forward is expected to carry its own stage range).

    Returns:
        ``pp_forward``, a function taking ``self`` plus the usual forward arguments and an
        extra ``hidden_states`` input used by non-first stages.
    """
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
    from transformers.models.vit.modeling_vit import ImageClassifierOutput

    def _infer_problem_type(num_labels, labels):
        """Choose a problem type from the label count and the label dtype."""
        if num_labels == 1:
            return "regression"
        if num_labels > 1 and labels.dtype in (torch.long, torch.int):
            return "single_label_classification"
        return "multi_label_classification"

    def _compute_loss(problem_type, num_labels, logits, labels):
        """Apply the loss matching ``problem_type``; unknown types yield ``None``."""
        if problem_type == "regression":
            criterion = MSELoss()
            if num_labels == 1:
                return criterion(logits.squeeze(), labels.squeeze())
            return criterion(logits, labels)
        if problem_type == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if not stage_manager.is_first_stage():
            assert (
                hidden_states is not None
            ), f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        vit_outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )

        # Intermediate stages simply hand the backbone output on to the next stage.
        if not stage_manager.is_last_stage():
            return vit_outputs

        # Last stage: classify from position 0 of the final sequence output.
        logits = self.classifier(vit_outputs[0][:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer once and cache on the config, as later calls read it back.
                self.config.problem_type = _infer_problem_type(self.num_labels, labels)
            loss = _compute_loss(self.config.problem_type, self.num_labels, logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=vit_outputs.hidden_states,
                attentions=vit_outputs.attentions,
            )

        flat = (logits,) + vit_outputs[1:]
        return flat if loss is None else (loss,) + flat

    return pp_forward
|
|
|
|
|
|
|
|
|
|
|
|
def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel replacement for ``ViTForMaskedImageModeling.forward``.

    Args:
        stage_manager: Reports whether the current rank is the first and/or last stage.
        stage_index: Layer range for this stage; not read directly here (the wrapped
            ``self.vit`` forward is expected to carry its own stage range).

    Returns:
        ``pp_forward``, a function taking ``self`` plus the usual forward arguments and an
        extra ``hidden_states`` input used by non-first stages.
    """
    import math

    import torch.nn as nn
    from transformers.models.vit.modeling_vit import MaskedImageModelingOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        hidden_states (`torch.FloatTensor`, *optional*):
            Activations received from the previous pipeline stage. Required on every stage
            except the first.

        Returns:
            On a non-last stage, whatever ``self.vit`` returns (the activations for the next
            stage). On the last stage, a tuple or ``MaskedImageModelingOutput`` depending on
            ``return_dict``; the loss is only computed when ``bool_masked_pos`` is given.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                "the reconstructed image has the same dimensions as the input. "
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
            )

        if not stage_manager.is_first_stage():
            assert (
                hidden_states is not None
            ), f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )
        if not stage_manager.is_last_stage():
            return outputs

        sequence_output = outputs[0]

        # Drop sequence position 0 (presumably the class token) and fold the remaining
        # tokens into a square grid: (batch_size, num_channels, height, width).
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            # Expand the per-patch mask to per-pixel resolution, with a channel axis at dim 1.
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            # Mean L1 over masked pixels; 1e-5 guards against an all-unmasked batch.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward
|
|
|
|
|
|
|
|
|
|
|
|
def get_vit_flash_self_attention_forward():
    """Return a ``ViTSelfAttention.forward`` replacement that computes attention with ``ColoAttention``."""
    from transformers.models.vit.modeling_vit import ViTSelfAttention

    def forward(
        self: ViTSelfAttention,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Project to Q/K/V, run ``ColoAttention.attention``, and merge the heads back.

        ``head_mask`` must be ``None``. When ``output_attentions`` is set, ``None`` is
        returned in place of the attention probabilities, which this path does not produce.
        """
        assert head_mask is None, "head_mask is not supported for FlashAttention"

        projected_query = self.query(hidden_states)
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))
        queries = self.transpose_for_scores(projected_query)

        # Dropout on attention weights only applies in training mode.
        attn_dropout = 0.0
        if self.training:
            attn_dropout = self.dropout.p
        context = ColoAttention.attention(queries, keys, values, dropout_p=attn_dropout)

        # Swap dims 1 and 2, then flatten the last two dims into all_head_size.
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        if output_attentions:
            return (context, None)
        return (context,)

    return forward
|
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_vit_output_forward():
    """Return a ``ViTOutput.forward`` that delegates dropout + residual add to ``self.dropout_add``."""
    from transformers.models.vit.modeling_vit import ViTOutput

    def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # NOTE(review): ``self.dropout_add`` is not defined in this file; it is expected to be
        # attached to the module elsewhere (presumably a JIT-fused dropout+add) — confirm.
        projected = self.dense(hidden_states)
        return self.dropout_add(projected, input_tensor, self.dropout.p, self.dropout.training)

    return forward
|
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_vit_intermediate_forward():
    """Return a ViT intermediate-layer ``forward`` that applies ColossalAI's JIT bias+GeLU."""
    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE(review): ``self.dense`` must return an (output, bias) pair, i.e. a linear layer
        # that skips its own bias add so the bias can be fused into GeLU — confirm in the policy.
        dense_out, dense_bias = self.dense(hidden_states)
        return JitGeLUFunction.apply(dense_out, dense_bias)

    return forward
|