[shardformer] support pipeline base vit model (#4284)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* support base vit pipeline

* support vit downstream model

* fix vit shard test

* modify hidden states return type

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
pull/4445/head
FoolPlayer 2023-07-25 15:02:29 +08:00 committed by Hongxin Liu
parent 083d7da33d
commit b3f5d7a3ba
7 changed files with 728 additions and 104 deletions
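For orientation, here is a minimal sketch of how the new pipeline forward can be exercised, mirroring the pipeline test added at the bottom of this diff. The 2x2 mesh layout, the `transformers_vit` registry name, the `(1, 197, 768)` hidden-state shape, and the `build_pipeline_model` call are all taken from that test; everything else is a hedged assumption, and the sketch is meant to run inside a spawned worker that has already called `colossalai.launch` (as `check_vit` does).

```python
import torch
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model


def vit_pipeline_smoke_test():
    # 2 data-parallel groups x 2 pipeline stages; axis 1 of the mesh is the pipeline axis,
    # matching DP_DIM, PP_DIM = 0, 1 in the test below
    pg_mesh = ProcessGroupMesh(2, 2)
    stage_manager = PipelineStageManager(pg_mesh, 1)

    for name, (model_fn, data_gen_fn, _, _, _) in model_zoo.get_sub_registry('transformers_vit').items():
        # no fused normalization, no tensor parallelism, no lazy init, as in the test defaults
        _, sharded_model = build_pipeline_model(model_fn, stage_manager, False, False, False)
        sharded_model.train()

        inputs = {k: v.cuda() for k, v in data_gen_fn().items()}
        if not stage_manager.is_first_stage():
            # non-first stages consume hidden_states instead of pixel_values
            inputs['hidden_states'] = torch.randn(1, 197, 768).cuda()

        output = sharded_model(**inputs)
        # non-last stages return {'hidden_states': ...}; the last stage returns the usual HF output
        if stage_manager.is_last_stage() and name != 'transformers_vit':
            assert output.loss is not None
```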


@@ -0,0 +1,337 @@
import logging
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
from colossalai.pipeline.stage_manager import PipelineStageManager
def _encoder_forward(
encoder: ViTEncoder,
start_idx: int,
end_idx: int,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
stage_manager: PipelineStageManager = None,
) -> Union[tuple, BaseModelOutput]:
for i in range(start_idx, end_idx):
layer_module = encoder.layer[i]
layer_head_mask = head_mask[i] if head_mask is not None else None
if encoder.gradient_checkpointing and encoder.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, False)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
layer_head_mask,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, False)
hidden_states = layer_outputs[0]
if not stage_manager.is_last_stage():
return hidden_states
else:
if not return_dict:
return tuple(hidden_states)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=None,
attentions=None,
)
def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling
def pp_forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions is not None:
logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
output_attentions = None
if output_hidden_states is not None:
logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
output_hidden_states = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if stage_manager.is_first_stage():
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(pixel_values,
bool_masked_pos=bool_masked_pos,
interpolate_pos_encoding=interpolate_pos_encoding)
else:
assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
# Go through encoder
if not stage_manager.is_last_stage():
hidden_states = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=embedding_output,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
return {'hidden_states': hidden_states}
else:
encoder_outputs = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=hidden_states,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
# Go through rest layers
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
return pp_forward
def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.vit.modeling_vit import ImageClassifierOutput
def pp_forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if not stage_manager.is_first_stage():
assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
outputs = self.vit(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
hidden_states=hidden_states,
)
# not last stage, return hidden_states
if not stage_manager.is_last_stage():
return outputs
else:
sequence_output = outputs[0]
# last stage
logits = self.classifier(sequence_output[:, 0, :])
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return pp_forward
def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
import math
import torch.nn as nn
from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput
def pp_forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
>>> # create random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
raise ValueError(
"When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
"the reconstructed image has the same dimensions as the input."
f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.")
if not stage_manager.is_first_stage():
assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
outputs = self.vit(pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
hidden_states=hidden_states)
if not stage_manager.is_last_stage():
return outputs
else:
sequence_output = outputs[0]
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output[:, 1:]
batch_size, sequence_length, num_channels = sequence_output.shape
height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
reconstructed_pixel_values = self.decoder(sequence_output)
masked_im_loss = None
if bool_masked_pos is not None:
size = self.config.image_size // self.config.patch_size
bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
mask = (bool_masked_pos.repeat_interleave(self.config.patch_size,
1).repeat_interleave(self.config.patch_size,
2).unsqueeze(1).contiguous())
reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
if not return_dict:
output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedImageModelingOutput(
loss=masked_im_loss,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return pp_forward
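The `start_idx`/`end_idx` passed into `_encoder_forward` come from the policy's layer distribution. As a rough illustration, the stand-in below (`even_split`, written only for this note) shows the kind of per-stage bounds the function receives; the real `Policy.distribute_layers`/`Policy.get_stage_index` helpers used in the policy file further down may place remainder layers differently.

```python
from typing import List, Tuple


def even_split(num_layers: int, num_stages: int) -> List[Tuple[int, int]]:
    # simplified, roughly-even split of encoder layers across pipeline stages
    base, rem = divmod(num_layers, num_stages)
    bounds, start = [], 0
    for stage in range(num_stages):
        end = start + base + (1 if stage < rem else 0)
        bounds.append((start, end))
        start = end
    return bounds


# With the 4-layer test config registered below and 2 pipeline stages:
#   stage 0 runs encoder.layer[0:2] and returns {'hidden_states': ...}
#   stage 1 runs encoder.layer[2:4], then layernorm/pooler, and builds the final output
print(even_split(4, 2))   # [(0, 2), (2, 4)]
print(even_split(12, 4))  # [(0, 3), (3, 6), (6, 9), (9, 12)]
```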


@@ -75,6 +75,14 @@ _POLICY_LIST = {
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
# ViT
"transformers.models.vit.modeling_vit.ViTModel":
PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
"transformers.models.vit.modeling_vit.ViTForImageClassification":
PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
"transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),
# OPT
"transformers.models.opt.modeling_opt.OPTModel":
PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
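The three ViT entries above are consumed by shardformer's automatic policy lookup. The resolver sketched below is hypothetical (written only for this note, assuming it sits next to `_POLICY_LIST` in this module); only the `_POLICY_LIST`/`PolicyLocation` mapping itself comes from this file, and the real resolver may differ in name and details.

```python
import importlib


def resolve_policy(model_obj):
    # hypothetical illustration of how a model instance maps to its policy class
    key = f"{model_obj.__class__.__module__}.{model_obj.__class__.__qualname__}"
    location = _POLICY_LIST[key]  # e.g. PolicyLocation(file_name="vit", class_name="ViTModelPolicy")
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()


# resolve_policy(transformers.ViTModel(vit_config)) would return a ViTModelPolicy instance
```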


@@ -1,12 +1,18 @@
from typing import Dict, Union
from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
import colossalai.shardformer.layer as col_nn
from ..modeling.vit import (
ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ViTPolicy']
__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']
class ViTPolicy(Policy):
@@ -15,96 +21,203 @@ class ViTPolicy(Policy):
pass
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
base_policy = {
ViTEmbeddings:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
]),
ViTLayer:
ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=DropoutForParallelInput,
),
]),
}
policy = {}
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[ViTAttention].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="layernorm_before",
target_module=FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layernorm_after",
target_module=FusedLayerNorm,
)
])
base_policy[ViTModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="layernorm",
target_module=FusedLayerNorm,
))
if self.shard_config.enable_tensor_parallelism:
policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
return base_policy
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
return policy
def new_model_class(self):
return None
def postprocess(self):
return self.model
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
if self.model.__class__.__name__ == 'ViTModel':
module = self.model
else:
module = self.model.vit
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'ViTModel':
module = self.model
else:
module = self.model.vit
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
# ViTModel
class ViTModelPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTModel
policy = super().module_policy()
if self.shard_config.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(module.pooler)
return held_layers
# ViTForImageClassification
class ViTForImageClassificationPolicy(ViTPolicy):
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
new_item = {
ViTForImageClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.shard_config.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
self.set_pipeline_forward(model_cls=ViTForImageClassification,
pipeline_forward=ViTForImageClassification_pipeline_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
module = self.model.vit
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(self.model.classifier)
return held_layers
# ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
policy = super().module_policy()
if self.shard_config.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling,
pipeline_forward=ViTForMaskedImageModeling_pipeline_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
module = self.model.vit
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(self.model.decoder)
return held_layers
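To make the stage assignment concrete, here is a plain-Python sketch of what the `get_held_layers` overrides above produce for `ViTForImageClassification`. It mirrors their logic for illustration only (using the same roughly-even split as the earlier `even_split` sketch) and does not call the real policy classes.

```python
from typing import List


def describe_stage(stage: int, num_stages: int, num_encoder_layers: int) -> List[str]:
    # roughly-even split of encoder layers, as in the even_split sketch shown earlier
    base, rem = divmod(num_encoder_layers, num_stages)
    start = stage * base + min(stage, rem)
    end = start + base + (1 if stage < rem else 0)

    held = []
    if stage == 0:
        held.append("vit.embeddings")                      # first stage also holds the patch embeddings
    held += [f"vit.encoder.layer[{i}]" for i in range(start, end)]
    if stage == num_stages - 1:
        held += ["vit.layernorm", "classifier"]            # last stage holds the final norm and the head
    return held


# With the 4-layer test config and 2 stages:
#   describe_stage(0, 2, 4) -> ['vit.embeddings', 'vit.encoder.layer[0]', 'vit.encoder.layer[1]']
#   describe_stage(1, 2, 4) -> ['vit.encoder.layer[2]', 'vit.encoder.layer[3]', 'vit.layernorm', 'classifier']
```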


@@ -5,3 +5,4 @@ from .gpt import *
from .llama import *
from .opt import *
from .t5 import *
from .vit import *


@@ -0,0 +1,68 @@
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence VIT
# ===============================
config = transformers.ViTConfig(
num_hidden_layers=4,
# hidden_size=128,
# intermediate_size=256,
num_attention_heads=4)
# define data gen function
def data_gen():
pixel_values = torch.randn(1, 3, 224, 224)
return dict(pixel_values=pixel_values)
def data_gen_for_image_classification():
data = data_gen()
data['labels'] = torch.tensor([0])
return data
def data_gen_for_masked_image_modeling():
data = data_gen()
num_patches = (config.image_size // config.patch_size)**2
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
data['bool_masked_pos'] = bool_masked_pos
return data
# define output transform function
output_transform_fn = lambda x: x
# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss
# register the following models
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name='transformers_vit',
model_fn=lambda: transformers.ViTModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_vit_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_vit_for_masked_image_modeling',
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
data_gen_fn=data_gen_for_masked_image_modeling,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_masked_image_modeling,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_vit_for_image_classification',
model_fn=lambda: transformers.ViTForImageClassification(config),
data_gen_fn=data_gen_for_image_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_image_classification,
model_attribute=ModelAttribute(has_control_flow=True))
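As a quick single-process sanity check of these registrations, the loop below mirrors how the shard tests further down consume the registry; the `(model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)` tuple layout is assumed from those tests rather than from documented API.

```python
def sanity_check_vit_zoo():
    # runs each registered ViT variant once on CPU and backpropagates its loss
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in \
            model_zoo.get_sub_registry('transformers_vit').items():
        model = model_fn()
        outputs = output_transform_fn(model(**data_gen_fn()))
        loss = loss_fn(outputs)
        loss.backward()
        print(f"{name}: loss={loss.item():.4f}")
```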


@@ -1,9 +1,18 @@
import os
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -12,44 +21,58 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output)
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
# do backward
org_loss.backward()
shard_loss.backward()
# check grad
org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.vit
# check attention grad
org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad
shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{shard_grad}"
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
run_vit_test()
@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 4)
spawn(check_vit, 2)
if __name__ == "__main__":


@@ -0,0 +1,74 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_vit
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
pixel_values = inputs['pixel_values']
batch_size = len(pixel_values)
hidden_size = 768
hidden_state_shape = (batch_size, 197, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.randn(*hidden_state_shape).cuda()
# inputs['pixel_values'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name != 'transformers_vit':
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape, \
f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'
torch.cuda.empty_cache()
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 4)
if __name__ == "__main__":
test_vit()