mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support pipeline base vit model (#4284)
* Feature/vit support (#4182)
  * [shardformer] added tests
  * [shardformer] vit test finish and support
  * fix attention dropout
  * support base vit pipeline
  * support vit downstream model
  * fix vit shard test
  * modify hidden states return type

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
parent 083d7da33d
commit b3f5d7a3ba
@@ -0,0 +1,337 @@
import logging
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder

from colossalai.pipeline.stage_manager import PipelineStageManager


def _encoder_forward(
    encoder: ViTEncoder,
    start_idx: int,
    end_idx: int,
    hidden_states: torch.Tensor,
    head_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    stage_manager: PipelineStageManager = None,
) -> Union[tuple, BaseModelOutput]:

    for i in range(start_idx, end_idx):
        layer_module = encoder.layer[i]

        layer_head_mask = head_mask[i] if head_mask is not None else None

        if encoder.gradient_checkpointing and encoder.training:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    return module(*inputs, False)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer_module),
                hidden_states,
                layer_head_mask,
            )
        else:
            layer_outputs = layer_module(hidden_states, layer_head_mask, False)

        hidden_states = layer_outputs[0]
    if not stage_manager.is_last_stage():
        return hidden_states
    else:
        if not return_dict:
            return tuple(hidden_states)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=None,
            attentions=None,
        )


def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):

    from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if output_attentions is not None:
            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
            output_attentions = None
        if output_hidden_states is not None:
            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
            output_hidden_states = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if stage_manager.is_first_stage():
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
            expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
            if pixel_values.dtype != expected_dtype:
                pixel_values = pixel_values.to(expected_dtype)

            embedding_output = self.embeddings(pixel_values,
                                               bool_masked_pos=bool_masked_pos,
                                               interpolate_pos_encoding=interpolate_pos_encoding)
        else:
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        # Go through encoder
        if not stage_manager.is_last_stage():
            hidden_states = _encoder_forward(
                encoder=self.encoder,
                start_idx=stage_index[0],
                end_idx=stage_index[1],
                hidden_states=embedding_output,
                head_mask=head_mask,
                return_dict=return_dict,
                stage_manager=stage_manager,
            )
            return {'hidden_states': hidden_states}
        else:
            encoder_outputs = _encoder_forward(
                encoder=self.encoder,
                start_idx=stage_index[0],
                end_idx=stage_index[1],
                hidden_states=hidden_states,
                head_mask=head_mask,
                return_dict=return_dict,
                stage_manager=stage_manager,
            )

        # Go through rest layers
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    return pp_forward


def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):

    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
    from transformers.models.vit.modeling_vit import ImageClassifierOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not stage_manager.is_first_stage():
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )

        # not last stage, return hidden_states
        if not stage_manager.is_last_stage():
            return outputs
        else:
            sequence_output = outputs[0]

        # last stage
        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward


def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):

    import math

    import torch.nn as nn
    from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                "the reconstructed image has the same dimensions as the input."
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.")

        if not stage_manager.is_first_stage():
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(pixel_values,
                           bool_masked_pos=bool_masked_pos,
                           head_mask=head_mask,
                           output_attentions=output_attentions,
                           output_hidden_states=output_hidden_states,
                           interpolate_pos_encoding=interpolate_pos_encoding,
                           return_dict=return_dict,
                           hidden_states=hidden_states)
        if not stage_manager.is_last_stage():
            return outputs
        else:
            sequence_output = outputs[0]

        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (bool_masked_pos.repeat_interleave(self.config.patch_size,
                                                      1).repeat_interleave(self.config.patch_size,
                                                                           2).unsqueeze(1).contiguous())
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward
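The three factory functions above each return a `pp_forward` closure that the ViT policies install as the model's `forward` once a pipeline stage manager is available. Below is a minimal sketch of that binding done by hand, assuming the module path `colossalai.shardformer.modeling.vit` (consistent with the policy's relative import further down) and an already-constructed `PipelineStageManager`; in ColossalAI itself the binding goes through the policy's method-replacement machinery rather than this helper.

```python
import types

from transformers import ViTConfig, ViTModel

from colossalai.shardformer.modeling.vit import ViTModel_pipeline_forward


def attach_pipeline_forward(model: ViTModel, stage_manager, stage_index):
    # The factory captures stage_manager/stage_index and returns an unbound forward.
    pp_forward = ViTModel_pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)
    # Bind it to the instance so model(pixel_values=...) routes through the pipeline-aware forward.
    # Intermediate stages then return {'hidden_states': ...}; the last stage returns the usual
    # BaseModelOutputWithPooling.
    model.forward = types.MethodType(pp_forward, model)
    return model
```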
@@ -75,6 +75,14 @@ _POLICY_LIST = {
    "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
        PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),

+    # ViT
+    "transformers.models.vit.modeling_vit.ViTModel":
+        PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
+    "transformers.models.vit.modeling_vit.ViTForImageClassification":
+        PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
+    "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
+        PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),
+
    # OPT
    "transformers.models.opt.modeling_opt.OPTModel":
        PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
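The registry above keys each policy by the fully qualified class name of the wrapped HuggingFace model. A small illustrative sketch of how such a key can be derived from a model instance (not ColossalAI's actual lookup code; the helper name is an assumption):

```python
import transformers


def _fully_qualified_class_name(model) -> str:
    cls = model.__class__
    return f"{cls.__module__}.{cls.__qualname__}"


# A transformers ViTModel instance yields "transformers.models.vit.modeling_vit.ViTModel",
# which the table above maps to PolicyLocation(file_name="vit", class_name="ViTModelPolicy").
vit = transformers.ViTModel(transformers.ViTConfig(num_hidden_layers=1))
assert _fully_qualified_class_name(vit) == "transformers.models.vit.modeling_vit.ViTModel"
```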
@@ -1,12 +1,18 @@
-from typing import Dict, Union
+from functools import partial
+from typing import Callable, Dict, List, Union

import torch.nn as nn

-from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+import colossalai.shardformer.layer as col_nn

+from ..modeling.vit import (
+    ViTForImageClassification_pipeline_forward,
+    ViTForMaskedImageModeling_pipeline_forward,
+    ViTModel_pipeline_forward,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

-__all__ = ['ViTPolicy']
+__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']


class ViTPolicy(Policy):
@@ -15,96 +21,203 @@ class ViTPolicy(Policy):
        pass

    def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
-
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer

-        base_policy = {
-            ViTEmbeddings:
-                ModulePolicyDescription(sub_module_replacement=[
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
+                                                            param_replacement=[],
+                                                            sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="dropout",
-                        target_module=DropoutForReplicatedInput,
+                        target_module=col_nn.DropoutForReplicatedInput,
                    )
-                ]),
-            ViTLayer:
-                ModulePolicyDescription(attribute_replacement={
+            ])
+
+            policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
                    "attention.attention.num_attention_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "attention.attention.all_head_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.attention.query",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.key",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.value",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.dropout",
-                        target_module=DropoutForParallelInput,
+                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
-                        target_module=Linear1D_Row,
+                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
-                        target_module=DropoutForParallelInput,
+                        target_module=col_nn.DropoutForReplicatedInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate.dense",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dense",
-                        target_module=Linear1D_Row,
+                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dropout",
-                        target_module=DropoutForParallelInput,
+                        target_module=col_nn.DropoutForReplicatedInput,
                    ),
                ]),
-        }
-
-        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            base_policy[ViTAttention].sub_module_replacement.extend([
-                SubModuleReplacementDescription(
-                    suffix="layernorm_before",
-                    target_module=FusedLayerNorm,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="layernorm_after",
-                    target_module=FusedLayerNorm,
-                )
-            ])
-            base_policy[ViTModel].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="layernorm",
-                    target_module=FusedLayerNorm,
-                ))
-
-        return base_policy
+        return policy

    def new_model_class(self):
        return None

    def postprocess(self):
        return self.model
+
+    def get_held_layers(self) -> List[nn.Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        if self.model.__class__.__name__ == 'ViTModel':
+            module = self.model
+        else:
+            module = self.model.vit
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.embeddings)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.encoder.layer[start_idx:end_idx])
+        return held_layers
+
+    def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
+        if self.pipeline_stage_manager:
+            stage_manager = self.pipeline_stage_manager
+            if self.model.__class__.__name__ == 'ViTModel':
+                module = self.model
+            else:
+                module = self.model.vit
+
+            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=model_cls)
+
+
+# ViTModel
+class ViTModelPolicy(ViTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTModel
+
+        policy = super().module_policy()
+
+        if self.shard_config.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage():
+            held_layers.append(module.layernorm)
+            held_layers.append(module.pooler)
+
+        return held_layers
+
+
+# ViTForImageClassification
+class ViTForImageClassificationPolicy(ViTPolicy):
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel
+
+        policy = super().module_policy()
+        if self.shard_config.enable_tensor_parallelism:
+            new_item = {
+                ViTForImageClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
+                    ])
+            }
+            policy.update(new_item)
+
+        if self.shard_config.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+            self.set_pipeline_forward(model_cls=ViTForImageClassification,
+                                      pipeline_forward=ViTForImageClassification_pipeline_forward,
+                                      policy=policy)
+
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        module = self.model.vit
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage():
+            held_layers.append(module.layernorm)
+            held_layers.append(self.model.classifier)
+
+        return held_layers
+
+
+# ViTForMaskedImageModeling
+class ViTForMaskedImageModelingPolicy(ViTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
+
+        policy = super().module_policy()
+
+        if self.shard_config.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+            self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling,
+                                      pipeline_forward=ViTForMaskedImageModeling_pipeline_forward,
+                                      policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        module = self.model.vit
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage():
+            held_layers.append(module.layernorm)
+            held_layers.append(self.model.decoder)
+
+        return held_layers
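The new `get_held_layers()` and `set_pipeline_forward()` above lean on `Policy.distribute_layers` and `Policy.get_stage_index` to decide which encoder blocks each pipeline stage owns. A minimal sketch of that kind of splitting, assuming a plain even split (the real helpers in the `Policy` base class may distribute remainders differently):

```python
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Even split; in this sketch any remainder goes to the later stages.
    quotient, remainder = divmod(num_layers, num_stages)
    return [quotient + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


# A 12-layer ViT on 4 stages -> [3, 3, 3, 3]; stage 1 then runs encoder.layer[3:6],
# which is the slice get_held_layers() keeps and the stage_index handed to pp_forward.
assert distribute_layers(12, 4) == [3, 3, 3, 3]
assert get_stage_index([3, 3, 3, 3], 1) == (3, 6)
```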
@@ -5,3 +5,4 @@ from .gpt import *
from .llama import *
from .opt import *
from .t5 import *
+from .vit import *
@@ -0,0 +1,68 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence VIT
# ===============================

config = transformers.ViTConfig(
    num_hidden_layers=4,
    # hidden_size=128,
    # intermediate_size=256,
    num_attention_heads=4)


# define data gen function
def data_gen():
    pixel_values = torch.randn(1, 3, 224, 224)
    return dict(pixel_values=pixel_values)


def data_gen_for_image_classification():
    data = data_gen()
    data['labels'] = torch.tensor([0])
    return data


def data_gen_for_masked_image_modeling():
    data = data_gen()
    num_patches = (config.image_size // config.patch_size)**2
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
    data['bool_masked_pos'] = bool_masked_pos
    return data


# define output transform function
output_transform_fn = lambda x: x

# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss

# register the following models
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name='transformers_vit',
                   model_fn=lambda: transformers.ViTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_vit_model,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_vit_for_masked_image_modeling',
                   model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
                   data_gen_fn=data_gen_for_masked_image_modeling,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_masked_image_modeling,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_vit_for_image_classification',
                   model_fn=lambda: transformers.ViTForImageClassification(config),
                   data_gen_fn=data_gen_for_image_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_image_classification,
                   model_attribute=ModelAttribute(has_control_flow=True))
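The three entries registered above are what the shardformer tests below iterate over. A short sketch of how a registered entry can be pulled back out and exercised, mirroring the loops in those tests (the direct forward call here is only for illustration):

```python
from tests.kit.model_zoo import model_zoo

# Fetch every entry whose name starts with 'transformers_vit'.
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    model = model_fn()                    # tiny 4-layer ViT built from `config`
    batch = data_gen_fn()                 # pixel_values (+ labels / bool_masked_pos)
    output = output_transform_fn(model(**batch))
    loss = loss_fn(output)                # scalar tensor the tests backpropagate through
```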
@@ -1,9 +1,18 @@
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -12,44 +21,58 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output)
-
+    assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
    # do backward
    org_loss.backward()
    shard_loss.backward()

-    # check grad
-    org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
-    shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"

+    # unwrap model
+    if org_model.__class__.__name__ == 'ViTModel':
+        vit_model = org_model
+        shard_vit_model = sharded_model
+    else:
+        vit_model = org_model.vit
+        shard_vit_model = sharded_model.vit

+    # check attention grad
+    org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight

+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    else:
+        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()


def check_vit(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_vit_test()


@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
-    spawn(check_vit, 4)
+    spawn(check_vit, 2)


if __name__ == "__main__":
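For reference, the gather-and-compare step above reconstructs the full gradient of a Linear1D_Col-sharded weight before comparing it with the unsharded model. A hedged sketch of that idea in isolation (illustrative helper, not the test's exact code; note that `torch.distributed.all_gather` fills the list in place rather than returning the gathered tensor):

```python
import torch
import torch.distributed as dist


def gather_column_parallel_grad(shard_grad: torch.Tensor, tp_world_size: int = 2) -> torch.Tensor:
    # Collect every rank's shard of the gradient, then stitch the shards back together
    # along dim 0, which is the dimension Linear1D_Col splits.
    buffers = [torch.zeros_like(shard_grad) for _ in range(tp_world_size)]
    dist.all_gather(buffers, shard_grad)
    return torch.cat(buffers, dim=0)
```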
@@ -0,0 +1,74 @@
import pytest
import torch

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # TODO: add tests for forward/backward later
    pass


@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_vit
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
    DP_DIM, PP_DIM = 0, 1
    DP_SIZE, PP_SIZE = 2, 2
    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')

    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():

        inputs = data_gen_fn()
        inputs = {k: v.cuda() for k, v in inputs.items()}
        pixel_values = inputs['pixel_values']
        batch_size = len(pixel_values)
        hidden_size = 768
        hidden_state_shape = (batch_size, 197, hidden_size)

        if not stage_manager.is_first_stage():
            # change inputs if not the first stage
            hidden_states = torch.randn(*hidden_state_shape).cuda()
            # inputs['pixel_values'] = None
            inputs['hidden_states'] = hidden_states

        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                enable_tensor_parallelism, use_lazy_init)
        sharded_model.train()

        output = sharded_model(**inputs)
        if stage_manager.is_last_stage():
            if name != 'transformers_vit':
                assert output.loss is not None
        else:
            assert output['hidden_states'].shape == hidden_state_shape, \
                f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'

    torch.cuda.empty_cache()


def check_vit(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_vit_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
    spawn(check_vit, 4)


if __name__ == "__main__":
    test_vit()
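The hard-coded (batch_size, 197, 768) hidden-state shape in the pipeline test follows from the ViTConfig defaults kept by the model-zoo config above (image_size=224, patch_size=16, hidden_size=768); a quick check:

```python
image_size, patch_size, hidden_size = 224, 16, 768    # ViTConfig defaults kept by the zoo config
num_patches = (image_size // patch_size) ** 2          # 196 patch tokens
seq_len = num_patches + 1                               # +1 for the [CLS] token
assert (seq_len, hidden_size) == (197, 768)
```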