# Mirror of https://github.com/hpcaitech/ColossalAI (338 lines, 14 KiB, Python)
|
import logging
|
||
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
|
||
|
|
||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||
|
|
||
|
|
||
|
def _encoder_forward(
    encoder: "ViTEncoder",
    start_idx: int,
    end_idx: int,
    hidden_states: torch.Tensor,
    head_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    stage_manager: Optional["PipelineStageManager"] = None,
) -> Union[tuple, "BaseModelOutput"]:
    """Run ``encoder.layer[start_idx:end_idx]`` on ``hidden_states`` for one pipeline stage.

    Args:
        encoder: encoder whose ``layer`` list is partially executed; its
            ``gradient_checkpointing`` and ``training`` flags select the checkpointed path.
        start_idx: index of the first layer owned by this stage (inclusive).
        end_idx: index one past the last layer owned by this stage (exclusive).
        hidden_states: input activations for ``encoder.layer[start_idx]``.
        head_mask: optional per-layer head mask, indexed by the absolute layer index.
        return_dict: on the last stage, return a ``BaseModelOutput`` instead of a tuple.
        stage_manager: pipeline stage manager. ``None`` is treated as a single-stage
            run, i.e. the current stage is also the last one.

    Returns:
        On a non-last stage, the raw hidden-states tensor for the next stage. On the
        last stage, ``BaseModelOutput(last_hidden_state=...)``, or the 1-tuple
        ``(hidden_states,)`` when ``return_dict`` is False. Attentions and per-layer
        hidden states are never collected here.
    """
    for i in range(start_idx, end_idx):
        layer_module = encoder.layer[i]
        layer_head_mask = head_mask[i] if head_mask is not None else None

        if encoder.gradient_checkpointing and encoder.training:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    # Third positional arg is output_attentions, always disabled here.
                    return module(*inputs, False)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer_module),
                hidden_states,
                layer_head_mask,
            )
        else:
            layer_outputs = layer_module(hidden_states, layer_head_mask, False)

        hidden_states = layer_outputs[0]

    is_last_stage = stage_manager is None or stage_manager.is_last_stage()
    if not is_last_stage:
        # Intermediate stages hand the raw activations to the next stage.
        return hidden_states

    if not return_dict:
        # Wrap, do not convert: tuple(tensor) would split it along the batch dimension.
        return (hidden_states,)
    return BaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=None,
        attentions=None,
    )
|
||
|
|
||
|
|
||
|
def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel ``forward`` for a ``ViTModel``-like module.

    Args:
        stage_manager: pipeline stage manager used to decide whether this rank runs
            the embeddings (first stage) and the final layernorm/pooler (last stage).
        stage_index: ``[start, end)`` range of encoder layers owned by this stage.

    Returns:
        ``pp_forward``, a function taking ``self`` as its first argument, meant to be
        bound as the module's ``forward``. It reads ``self.config``, ``self.embeddings``,
        ``self.encoder``, ``self.layernorm``, ``self.pooler`` and ``self.get_head_mask``.
    """
    from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        pixel_values: input images; required on the first stage only.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        hidden_states: activations received from the previous stage; required on every
            stage except the first.

        Returns:
            On a non-last stage, ``{'hidden_states': tensor}`` for the next stage. On the
            last stage, a ``BaseModelOutputWithPooling``, or a tuple when ``return_dict``
            is False.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # After the config fallback above these are normally booleans rather than
        # None, so test truthiness: an `is not None` check would warn on every call.
        if output_attentions:
            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
            output_attentions = None
        if output_hidden_states:
            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
            output_hidden_states = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if stage_manager.is_first_stage():
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
            expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
            if pixel_values.dtype != expected_dtype:
                pixel_values = pixel_values.to(expected_dtype)

            # The embedding output is the input of this stage's encoder layers, also
            # when the first stage is the only (and therefore last) stage.
            hidden_states = self.embeddings(pixel_values,
                                            bool_masked_pos=bool_masked_pos,
                                            interpolate_pos_encoding=interpolate_pos_encoding)
        else:
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        # Go through this stage's slice of the encoder.
        encoder_outputs = _encoder_forward(
            encoder=self.encoder,
            start_idx=stage_index[0],
            end_idx=stage_index[1],
            hidden_states=hidden_states,
            head_mask=head_mask,
            return_dict=return_dict,
            stage_manager=stage_manager,
        )

        if not stage_manager.is_last_stage():
            # Non-last stages return the raw activations for the next stage.
            return {'hidden_states': encoder_outputs}

        # Go through rest layers
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    return pp_forward
|
||
|
|
||
|
|
||
|
def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Create the pipeline-parallel ``forward`` for ``ViTForImageClassification``.

    The returned ``pp_forward`` delegates the backbone to ``self.vit``. It runs the
    classification head, and the loss when ``labels`` is given, only on the last
    pipeline stage. ``stage_index`` belongs to the shared factory signature but is
    not read here.
    """
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
    from transformers.models.vit.modeling_vit import ImageClassifierOutput

    def _head_loss(model, logits, labels):
        """Return the head loss for ``logits`` vs ``labels``, or ``None`` for an unknown problem type.

        If ``model.config.problem_type`` is unset, it is inferred from
        ``model.num_labels`` and the label dtype, and written back to the config.
        """
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        cfg = model.config
        n_labels = model.num_labels

        if cfg.problem_type is None:
            if n_labels == 1:
                cfg.problem_type = "regression"
            elif n_labels > 1 and labels.dtype in (torch.long, torch.int):
                cfg.problem_type = "single_label_classification"
            else:
                cfg.problem_type = "multi_label_classification"

        kind = cfg.problem_type
        if kind == "regression":
            if n_labels == 1:
                return MSELoss()(logits.squeeze(), labels.squeeze())
            return MSELoss()(logits, labels)
        if kind == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, n_labels), labels.view(-1))
        if kind == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        hidden_states: activations from the previous stage; required on every stage but the first.

        Returns:
            On a non-last stage, whatever ``self.vit`` returned. On the last stage, an
            ``ImageClassifierOutput``, or the tuple ``(loss?, logits, *vit_outputs[1:])``
            when ``return_dict`` is False.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if not stage_manager.is_first_stage():
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        vit_outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )

        # Intermediate stages simply pass the backbone output on.
        if not stage_manager.is_last_stage():
            return vit_outputs

        # Classify from the first token of the final sequence output.
        logits = self.classifier(vit_outputs[0][:, 0, :])
        loss = None if labels is None else _head_loss(self, logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=vit_outputs.hidden_states,
                attentions=vit_outputs.attentions,
            )
        tail = (logits,) + vit_outputs[1:]
        return tail if loss is None else (loss,) + tail

    return pp_forward
|
||
|
|
||
|
|
||
|
def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Create the pipeline-parallel ``forward`` for ``ViTForMaskedImageModeling``.

    The returned ``pp_forward`` delegates the backbone to ``self.vit``. Only the last
    pipeline stage reshapes the sequence output, runs ``self.decoder`` and computes
    the masked reconstruction loss. ``stage_index`` belongs to the shared factory
    signature but is not read here.
    """
    import math

    import torch.nn as nn
    from transformers.models.vit.modeling_vit import MaskedImageModelingOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        hidden_states: activations from the previous stage; required on every stage but the first.

        Returns:
            On a non-last stage, whatever ``self.vit`` returned. On the last stage, a
            ``MaskedImageModelingOutput``, or, when ``return_dict`` is False, the tuple
            ``(loss?, reconstruction, *vit_outputs[1:])``. ``loss`` is present only when
            ``bool_masked_pos`` is given.

        Raises:
            ValueError: if ``bool_masked_pos`` is given and ``config.patch_size`` differs
                from ``config.encoder_stride``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                "the reconstructed image has the same dimensions as the input. "
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.")

        if not stage_manager.is_first_stage():
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(pixel_values,
                           bool_masked_pos=bool_masked_pos,
                           head_mask=head_mask,
                           output_attentions=output_attentions,
                           output_hidden_states=output_hidden_states,
                           interpolate_pos_encoding=interpolate_pos_encoding,
                           return_dict=return_dict,
                           hidden_states=hidden_states)

        # Intermediate stages simply pass the backbone output on.
        if not stage_manager.is_last_stage():
            return outputs

        sequence_output = outputs[0]

        # Drop the first token (presumably [CLS]). Fold the remaining patch tokens
        # back into a grid: (batch, seq, channels) -> (batch, channels, h, w).
        # The reshape only succeeds when the patch tokens form a square grid.
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            # Upsample the patch-level mask to pixel resolution:
            # (batch, size, size) -> (batch, 1, size * patch_size, size * patch_size).
            mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                    .repeat_interleave(self.config.patch_size, 2)
                    .unsqueeze(1)
                    .contiguous())
            # NOTE(review): this needs pixel_values on the last stage as well; confirm
            # the pipeline schedule forwards it to every stage.
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            # Mean L1 error over masked pixels (1e-5 guards an all-zero mask),
            # further divided by the number of image channels.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward
|