mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
386 lines
16 KiB
386 lines
16 KiB
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder |
|
from transformers.utils import logging |
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager |
|
from colossalai.shardformer.layer import ColoAttention |
|
|
|
|
|
def _encoder_forward( |
|
encoder: ViTEncoder, |
|
start_idx: int, |
|
end_idx: int, |
|
hidden_states: torch.Tensor, |
|
head_mask: Optional[torch.Tensor] = None, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
stage_manager: PipelineStageManager = None, |
|
) -> Union[tuple, BaseModelOutput]: |
|
for i in range(start_idx, end_idx): |
|
layer_module = encoder.layer[i] |
|
|
|
layer_head_mask = head_mask[i] if head_mask is not None else None |
|
|
|
if encoder.gradient_checkpointing and encoder.training: |
|
layer_outputs = encoder._gradient_checkpointing_func( |
|
layer_module.__call__, |
|
hidden_states, |
|
layer_head_mask, |
|
output_attentions, |
|
) |
|
else: |
|
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) |
|
|
|
hidden_states = layer_outputs[0] |
|
if not stage_manager.is_last_stage(): |
|
return hidden_states |
|
else: |
|
if not return_dict: |
|
return tuple(hidden_states) |
|
return BaseModelOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=None, |
|
attentions=None, |
|
) |
|
|
|
|
|
def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel replacement for ``ViTModel.forward``.

    Args:
        stage_manager: Pipeline stage manager used to tell the first and last
            stages apart.
        stage_index: ``[start, end)`` range of encoder layers run by this stage;
            read on every call of the returned function.

    Returns:
        The ``pp_forward`` function.
    """
    from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        hidden_states (`torch.FloatTensor`, *optional*):
            Activations received from the previous stage. Required on every stage except the first.

        Intermediate stages return ``{"hidden_states": tensor}``. The last stage returns a
        ``BaseModelOutputWithPooling``, or a tuple when ``return_dict`` is False.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        logger = logging.get_logger(__name__)

        # Attention maps and per-layer hidden states are not gathered across
        # stages, so both flags are forced off with a one-time warning.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # 1.0 in head_mask keeps a head. An input of shape [num_heads] or
        # [num_hidden_layers x num_heads] is converted to
        # [num_hidden_layers x batch x num_heads x seq_length x seq_length].
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if not stage_manager.is_first_stage():
            assert (
                hidden_states is not None
            ), f"Current stage is {stage_manager.stage}, hidden_states should not be None"
        else:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
            weight_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
            if pixel_values.dtype != weight_dtype:
                pixel_values = pixel_values.to(weight_dtype)

            # The first stage builds its own input from the patch embeddings.
            hidden_states = self.embeddings(
                pixel_values,
                bool_masked_pos=bool_masked_pos,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )

        encoder_outputs = _encoder_forward(
            encoder=self.encoder,
            start_idx=stage_index[0],
            end_idx=stage_index[1],
            hidden_states=hidden_states,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            stage_manager=stage_manager,
        )

        if not stage_manager.is_last_stage():
            # Hand the activations to the next stage.
            return {"hidden_states": encoder_outputs}

        sequence_output = self.layernorm(encoder_outputs[0])
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        if pooled_output is None:
            head_outputs = (sequence_output,)
        else:
            head_outputs = (sequence_output, pooled_output)
        return head_outputs + encoder_outputs[1:]

    return pp_forward
|
|
|
|
|
def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel replacement for ``ViTForImageClassification.forward``.

    Args:
        stage_manager: Pipeline stage manager used to tell the first and last
            stages apart.
        stage_index: Not read by this factory. Presumably the layer range is
            applied by the pipelined ``self.vit`` forward (TODO confirm).

    Returns:
        The ``pp_forward`` function.
    """
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
    from transformers.models.vit.modeling_vit import ImageClassifierOutput

    def _compute_loss(model, logits, labels):
        """Return the loss for ``labels``.

        Infers and stores ``model.config.problem_type`` when it is unset. Returns
        ``None`` for an unrecognized problem type.
        """
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        config = model.config
        n_labels = model.num_labels

        if config.problem_type is None:
            if n_labels == 1:
                config.problem_type = "regression"
            elif n_labels > 1 and labels.dtype in (torch.long, torch.int):
                config.problem_type = "single_label_classification"
            else:
                config.problem_type = "multi_label_classification"

        problem = config.problem_type
        if problem == "regression":
            criterion = MSELoss()
            if n_labels == 1:
                return criterion(logits.squeeze(), labels.squeeze())
            return criterion(logits, labels)
        if problem == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, n_labels), labels.view(-1))
        if problem == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        hidden_states (`torch.FloatTensor`, *optional*):
            Activations from the previous pipeline stage. Required on every stage except the first.

        Intermediate stages return the backbone's output unchanged. The last stage
        returns an ``ImageClassifierOutput``, or a tuple when ``return_dict`` is False.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        if not stage_manager.is_first_stage():
            assert (
                hidden_states is not None
            ), f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )

        if not stage_manager.is_last_stage():
            # Intermediate stage: pass the backbone output straight through.
            return outputs

        # Last stage: classify from the token at sequence position 0.
        logits = self.classifier(outputs[0][:, 0, :])
        loss = None if labels is None else _compute_loss(self, logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output

    return pp_forward
|
|
|
|
|
def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
    """Build a pipeline-parallel replacement for ``ViTForMaskedImageModeling.forward``.

    Args:
        stage_manager: Pipeline stage manager used to tell the first and last
            stages apart.
        stage_index: Not read by this factory. Presumably the layer range is
            applied by the pipelined ``self.vit`` forward (TODO confirm).

    Returns:
        The ``pp_forward`` function.
    """
    import math

    import torch.nn as nn
    from transformers.models.vit.modeling_vit import MaskedImageModelingOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        hidden_states (`torch.FloatTensor`, *optional*):
            Activations from the previous pipeline stage. Required on every stage except the first.

        Returns:
            On intermediate stages, whatever ``self.vit`` returns, passed through unchanged.
            On the last stage, a ``MaskedImageModelingOutput``, or a tuple when ``return_dict``
            is False. It holds the reconstructed pixel values and, when ``bool_masked_pos`` is
            given, the masked L1 reconstruction loss.

        Raises:
            ValueError: If ``bool_masked_pos`` is given but ``config.patch_size`` differs from
                ``config.encoder_stride``.

        Examples (non-pipelined Hugging Face model):

        ```python
        >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                # The trailing space separates this sentence from the f-string that follows.
                "the reconstructed image has the same dimensions as the input. "
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
            )

        if not stage_manager.is_first_stage():
            assert (
                hidden_states is not None
            ), f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )
        if not stage_manager.is_last_stage():
            # Intermediate stage: pass the backbone output straight through.
            return outputs

        sequence_output = outputs[0]

        # Drop the token at position 0, then fold the remaining tokens into a
        # square (batch_size, num_channels, height, width) map for the decoder.
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            # NOTE(review): pixel_values must also be available on the last stage here.
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            # Expand the per-patch mask to per-pixel resolution and add a channel dim.
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            # Mean L1 over masked pixels; 1e-5 guards against an all-False mask.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward
|
|
|
|
|
def get_vit_flash_self_attention_forward():
    """Return a ``ViTSelfAttention.forward`` replacement that uses ``ColoAttention``.

    The replacement rejects ``head_mask``. It never materializes attention
    probabilities, so when ``output_attentions`` is True the second tuple
    element is ``None``.
    """
    from transformers.models.vit.modeling_vit import ViTSelfAttention

    def forward(
        self: ViTSelfAttention,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        assert head_mask is None, "head_mask is not supported for FlashAttention"

        # Projection order (query, key, value) matches the reference implementation.
        projected_q = self.query(hidden_states)
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))
        q = self.transpose_for_scores(projected_q)

        # Dropout on attention weights only while training.
        attn_dropout = 0.0
        if self.training:
            attn_dropout = self.dropout.p
        attended = ColoAttention.attention(q, k, v, dropout_p=attn_dropout)

        # Swap dims 1 and 2, then merge the last two dims into all_head_size.
        # Presumably (batch, heads, seq, head_dim) -> (batch, seq, all_head_size); TODO confirm.
        attended = attended.permute(0, 2, 1, 3).contiguous()
        attended = attended.view(attended.size()[:-2] + (self.all_head_size,))

        if output_attentions:
            return (attended, None)
        return (attended,)

    return forward
|
|
|
|
|
def get_jit_fused_vit_output_forward():
    """Return a ``ViTOutput.forward`` replacement.

    The replacement applies ``self.dense`` and then does the dropout and
    residual add in one ``self.dropout_add`` call.
    """
    from transformers.models.vit.modeling_vit import ViTOutput

    def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        drop = self.dropout
        # Dropout probability and train/eval flag come from the module's own dropout layer.
        return self.dropout_add(projected, input_tensor, drop.p, drop.training)

    return forward
|
|
|
|
|
def get_jit_fused_vit_intermediate_forward():
    """Return an intermediate-layer ``forward`` that fuses the bias add with GeLU.

    The replacement calls ColossalAI's JIT ``GeLUFunction`` on the dense output
    and its bias.
    """
    from colossalai.kernel.jit.bias_gelu import GeLUFunction as _FusedBiasGeLU

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # self.dense is expected to return (output, bias) rather than adding the
        # bias itself (presumably a skip-bias-add linear; TODO confirm).
        projected, bias = self.dense(hidden_states)
        return _FusedBiasGeLU.apply(projected, bias)

    return forward
|
|
|