mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] support pipeline base vit model (#4284)

* Feature/vit support (#4182)
  * [shardformer] added tests
  * [shardformer] vit test finish and support
  * fix attention dropout
  * support base vit pipeline
  * support vit downstream model
  * fix vit shard test
  * modify hidden states return type

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>

Branch: pull/4445/head
parent 083d7da33d
commit b3f5d7a3ba
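Before the diff itself, a minimal sketch of how the pipeline support added by this commit is exercised. It is modeled on the pipeline test at the end of this commit; the function name `demo_vit_pipeline` and the 2x2 process-group mesh are illustrative assumptions, not part of the change.

    # Illustrative sketch only -- the wiring below mirrors the pipeline test included in this commit.
    import torch

    import colossalai
    from colossalai.cluster import ProcessGroupMesh
    from colossalai.pipeline.stage_manager import PipelineStageManager
    from colossalai.testing import spawn
    from tests.kit.model_zoo import model_zoo
    from tests.test_shardformer.test_model._utils import build_pipeline_model


    def demo_vit_pipeline(rank, world_size, port):
        # hypothetical entry point: 2 data-parallel x 2 pipeline-parallel ranks, as in the test
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        pg_mesh = ProcessGroupMesh(2, 2)
        stage_manager = PipelineStageManager(pg_mesh, 1)    # dim 1 of the mesh is the pipeline axis

        for name, (model_fn, data_gen_fn, _, _, _) in model_zoo.get_sub_registry('transformers_vit').items():
            # each stage ends up holding only its slice of encoder layers (see get_held_layers below)
            _, sharded_model = build_pipeline_model(model_fn, stage_manager, False, False, False)
            inputs = {k: v.cuda() for k, v in data_gen_fn().items()}
            if not stage_manager.is_first_stage():
                # non-first stages consume hidden_states instead of pixel_values
                inputs['hidden_states'] = torch.randn(1, 197, 768).cuda()
            sharded_model.train()
            output = sharded_model(**inputs)


    if __name__ == "__main__":
        spawn(demo_vit_pipeline, 4)

With the default ViT-base sizes used by the tests, the intermediate stage exchanges a (1, 197, 768) hidden-state tensor, which is exactly what the shape assertion in the pipeline test checks.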
@@ -0,0 +1,337 @@
import logging
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder

from colossalai.pipeline.stage_manager import PipelineStageManager


def _encoder_forward(
    encoder: ViTEncoder,
    start_idx: int,
    end_idx: int,
    hidden_states: torch.Tensor,
    head_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    stage_manager: PipelineStageManager = None,
) -> Union[tuple, BaseModelOutput]:

    for i in range(start_idx, end_idx):
        layer_module = encoder.layer[i]

        layer_head_mask = head_mask[i] if head_mask is not None else None

        if encoder.gradient_checkpointing and encoder.training:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    return module(*inputs, False)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer_module),
                hidden_states,
                layer_head_mask,
            )
        else:
            layer_outputs = layer_module(hidden_states, layer_head_mask, False)

        hidden_states = layer_outputs[0]
    if not stage_manager.is_last_stage():
        return hidden_states
    else:
        if not return_dict:
            return tuple(hidden_states)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=None,
            attentions=None,
        )


def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):

    from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if output_attentions is not None:
            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
            output_attentions = None
        if output_hidden_states is not None:
            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
            output_hidden_states = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if stage_manager.is_first_stage():
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
            expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
            if pixel_values.dtype != expected_dtype:
                pixel_values = pixel_values.to(expected_dtype)

            embedding_output = self.embeddings(pixel_values,
                                               bool_masked_pos=bool_masked_pos,
                                               interpolate_pos_encoding=interpolate_pos_encoding)
        else:
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        # Go through encoder
        if not stage_manager.is_last_stage():
            hidden_states = _encoder_forward(
                encoder=self.encoder,
                start_idx=stage_index[0],
                end_idx=stage_index[1],
                hidden_states=embedding_output,
                head_mask=head_mask,
                return_dict=return_dict,
                stage_manager=stage_manager,
            )
            return {'hidden_states': hidden_states}
        else:
            encoder_outputs = _encoder_forward(
                encoder=self.encoder,
                start_idx=stage_index[0],
                end_idx=stage_index[1],
                hidden_states=hidden_states,
                head_mask=head_mask,
                return_dict=return_dict,
                stage_manager=stage_manager,
            )

        # Go through rest layers
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    return pp_forward


def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):

    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
    from transformers.models.vit.modeling_vit import ImageClassifierOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not stage_manager.is_first_stage():
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            hidden_states=hidden_states,
        )

        # not last stage, return hidden_states
        if not stage_manager.is_last_stage():
            return outputs
        else:
            sequence_output = outputs[0]

        # last stage
        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward


def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):

    import math

    import torch.nn as nn
    from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput

    def pp_forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                "the reconstructed image has the same dimensions as the input."
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.")

        if not stage_manager.is_first_stage():
            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"

        outputs = self.vit(pixel_values,
                           bool_masked_pos=bool_masked_pos,
                           head_mask=head_mask,
                           output_attentions=output_attentions,
                           output_hidden_states=output_hidden_states,
                           interpolate_pos_encoding=interpolate_pos_encoding,
                           return_dict=return_dict,
                           hidden_states=hidden_states)
        if not stage_manager.is_last_stage():
            return outputs
        else:
            sequence_output = outputs[0]

        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (bool_masked_pos.repeat_interleave(self.config.patch_size,
                                                      1).repeat_interleave(self.config.patch_size,
                                                                           2).unsqueeze(1).contiguous())
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return pp_forward
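A note on how the three forward builders above are meant to be consumed: each returns a closure that replaces the Hugging Face forward on a per-stage basis. The helper below is a hypothetical sketch, not part of the commit (the commit itself installs the closure through `append_or_create_method_replacement` in the policy file later in this diff); it assumes it lives in the same module as the code above.

    # Hypothetical helper, for illustration only.
    from transformers.models.vit.modeling_vit import ViTModel


    def install_pipeline_forward(model: ViTModel, stage_manager, stage_index):
        # Bind the stage-aware forward onto this ViTModel instance.
        # Intermediate stages then return {'hidden_states': ...}; the last stage
        # returns a BaseModelOutputWithPooling, matching ViTModel_pipeline_forward above.
        model.forward = ViTModel_pipeline_forward(stage_manager, stage_index).__get__(model, ViTModel)
        return model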
@@ -75,6 +75,14 @@ _POLICY_LIST = {
     "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
         PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),

+    # ViT
+    "transformers.models.vit.modeling_vit.ViTModel":
+        PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
+    "transformers.models.vit.modeling_vit.ViTForImageClassification":
+        PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
+    "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
+        PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),
+
     # OPT
     "transformers.models.opt.modeling_opt.OPTModel":
         PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
@@ -1,12 +1,18 @@
-from typing import Dict, Union
+from functools import partial
+from typing import Callable, Dict, List, Union

 import torch.nn as nn

-from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+import colossalai.shardformer.layer as col_nn

+from ..modeling.vit import (
+    ViTForImageClassification_pipeline_forward,
+    ViTForMaskedImageModeling_pipeline_forward,
+    ViTModel_pipeline_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

-__all__ = ['ViTPolicy']
+__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']


 class ViTPolicy(Policy):
@@ -15,96 +21,203 @@ class ViTPolicy(Policy):
         pass

     def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
-
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer

-        base_policy = {
-            ViTEmbeddings:
-                ModulePolicyDescription(sub_module_replacement=[
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
+                                                            param_replacement=[],
+                                                            sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="dropout",
-                        target_module=DropoutForReplicatedInput,
+                        target_module=col_nn.DropoutForReplicatedInput,
                     )
-                ]),
-            ViTLayer:
-                ModulePolicyDescription(attribute_replacement={
+                ])
+
+            policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
                     "attention.attention.num_attention_heads":
                         self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     "attention.attention.all_head_size":
                         self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 },
+                param_replacement=[],
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="attention.attention.query",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.key",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.value",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.dropout",
-                        target_module=DropoutForParallelInput,
+                        target_module=col_nn.DropoutForParallelInput,
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
-                        target_module=Linear1D_Row,
+                        target_module=col_nn.Linear1D_Row,
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dropout",
-                        target_module=DropoutForParallelInput,
+                        target_module=col_nn.DropoutForReplicatedInput,
                     ),
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
-                        target_module=Linear1D_Col,
+                        target_module=col_nn.Linear1D_Col,
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
-                        target_module=Linear1D_Row,
+                        target_module=col_nn.Linear1D_Row,
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dropout",
-                        target_module=DropoutForParallelInput,
+                        target_module=col_nn.DropoutForReplicatedInput,
                     ),
-                ]),
-        }
-
-        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            base_policy[ViTAttention].sub_module_replacement.extend([
-                SubModuleReplacementDescription(
-                    suffix="layernorm_before",
-                    target_module=FusedLayerNorm,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="layernorm_after",
-                    target_module=FusedLayerNorm,
-                )
-            ])
-            base_policy[ViTModel].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="layernorm",
-                    target_module=FusedLayerNorm,
-                ))
-
-        return base_policy
+                ])
+
+        return policy

     def new_model_class(self):
         return None

     def postprocess(self):
         return self.model
+
+    def get_held_layers(self) -> List[nn.Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        if self.model.__class__.__name__ == 'ViTModel':
+            module = self.model
+        else:
+            module = self.model.vit
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.embeddings)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.encoder.layer[start_idx:end_idx])
+        return held_layers
+
+    def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
+        if self.pipeline_stage_manager:
+            stage_manager = self.pipeline_stage_manager
+            if self.model.__class__.__name__ == 'ViTModel':
+                module = self.model
+            else:
+                module = self.model.vit
+
+            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=model_cls)
+
+
+# ViTModel
+class ViTModelPolicy(ViTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTModel
+
+        policy = super().module_policy()
+
+        if self.shard_config.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage():
+            held_layers.append(module.layernorm)
+            held_layers.append(module.pooler)
+
+        return held_layers
+
+
+# ViTForImageClassification
+class ViTForImageClassificationPolicy(ViTPolicy):
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel
+
+        policy = super().module_policy()
+        if self.shard_config.enable_tensor_parallelism:
+            new_item = {
+                ViTForImageClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
+                    ])
+            }
+            policy.update(new_item)
+
+        if self.shard_config.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+            self.set_pipeline_forward(model_cls=ViTForImageClassification,
+                                      pipeline_forward=ViTForImageClassification_pipeline_forward,
+                                      policy=policy)
+
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        module = self.model.vit
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage():
+            held_layers.append(module.layernorm)
+            held_layers.append(self.model.classifier)
+
+        return held_layers
+
+
+# ViTForMaskedImageModeling
+class ViTForMaskedImageModelingPolicy(ViTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
+
+        policy = super().module_policy()
+
+        if self.shard_config.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+            self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling,
+                                      pipeline_forward=ViTForMaskedImageModeling_pipeline_forward,
+                                      policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        module = self.model.vit
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage():
+            held_layers.append(module.layernorm)
+            held_layers.append(self.model.decoder)
+
+        return held_layers
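A concrete reading of `get_held_layers`, assuming `Policy.distribute_layers` splits the encoder evenly across stages (the 4-layer test configuration later in this diff relies on this); the stage assignment below is illustrative, not part of the commit.

    # Illustrative only: num_hidden_layers=4, 2 pipeline stages.
    #   stage 0: embeddings + encoder.layer[0:2]
    #   stage 1: encoder.layer[2:4] + layernorm
    #            (+ pooler for ViTModelPolicy, classifier for ViTForImageClassificationPolicy,
    #             decoder for ViTForMaskedImageModelingPolicy)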
@@ -5,3 +5,4 @@ from .gpt import *
 from .llama import *
 from .opt import *
 from .t5 import *
+from .vit import *
@@ -0,0 +1,68 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence VIT
# ===============================

config = transformers.ViTConfig(
    num_hidden_layers=4,
    # hidden_size=128,
    # intermediate_size=256,
    num_attention_heads=4)


# define data gen function
def data_gen():
    pixel_values = torch.randn(1, 3, 224, 224)
    return dict(pixel_values=pixel_values)


def data_gen_for_image_classification():
    data = data_gen()
    data['labels'] = torch.tensor([0])
    return data


def data_gen_for_masked_image_modeling():
    data = data_gen()
    num_patches = (config.image_size // config.patch_size)**2
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
    data['bool_masked_pos'] = bool_masked_pos
    return data


# define output transform function
output_transform_fn = lambda x: x

# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss

# register the following models
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name='transformers_vit',
                   model_fn=lambda: transformers.ViTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_vit_model,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_vit_for_masked_image_modeling',
                   model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
                   data_gen_fn=data_gen_for_masked_image_modeling,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_masked_image_modeling,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_vit_for_image_classification',
                   model_fn=lambda: transformers.ViTForImageClassification(config),
                   data_gen_fn=data_gen_for_image_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_image_classification,
                   model_attribute=ModelAttribute(has_control_flow=True))
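The registry entries above are consumed by the tests that follow: each entry supplies a constructor, a data generator, an output transform, and a loss function. A hedged sketch of that contract (the exact plumbing lives in `build_model`/`run_forward`, which are not part of this diff):

    # Sketch only: how a registered entry is typically driven by the tests below.
    from tests.kit.model_zoo import model_zoo

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in \
            model_zoo.get_sub_registry('transformers_vit').items():
        model = model_fn()                            # e.g. transformers.ViTModel(config) from above
        outputs = model(**data_gen_fn())              # pixel_values, plus labels / bool_masked_pos where registered
        loss = loss_fn(output_transform_fn(outputs))  # identity transform, then e.g. pooler_output.mean()
        loss.backward()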
@@ -1,9 +1,18 @@
+import os
+
 import pytest
 import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -12,44 +21,58 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check forward
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                   output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output)
+    assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)

     # do backward
     org_loss.backward()
     shard_loss.backward()

-    # check grad
-    org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
-    shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
+    # unwrap model
+    if org_model.__class__.__name__ == 'ViTModel':
+        vit_model = org_model
+        shard_vit_model = sharded_model
+    else:
+        vit_model = org_model.vit
+        shard_vit_model = sharded_model.vit
+
+    # check attention grad
+    org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
     assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
+
+
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()


 def check_vit(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_vit_test()


 @pytest.mark.dist
-@pytest.mark.skip
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_vit():
-    spawn(check_vit, 4)
+    spawn(check_vit, 2)


 if __name__ == "__main__":
@@ -0,0 +1,74 @@
import pytest
import torch

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # TODO: add tests for forward/backward later
    pass


@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_vit
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
    DP_DIM, PP_DIM = 0, 1
    DP_SIZE, PP_SIZE = 2, 2
    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')

    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():

        inputs = data_gen_fn()
        inputs = {k: v.cuda() for k, v in inputs.items()}
        pixel_values = inputs['pixel_values']
        batch_size = len(pixel_values)
        hidden_size = 768
        hidden_state_shape = (batch_size, 197, hidden_size)

        if not stage_manager.is_first_stage():
            # change inputs if not the first stage
            hidden_states = torch.randn(*hidden_state_shape).cuda()
            # inputs['pixel_values'] = None
            inputs['hidden_states'] = hidden_states

        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                enable_tensor_parallelism, use_lazy_init)
        sharded_model.train()

        output = sharded_model(**inputs)
        if stage_manager.is_last_stage():
            if name != 'transformers_vit':
                assert output.loss is not None
        else:
            assert output['hidden_states'].shape == hidden_state_shape, \
                f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'

    torch.cuda.empty_cache()


def check_vit(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_vit_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
    spawn(check_vit, 4)


if __name__ == "__main__":
    test_vit()