ColossalAI/tests/test_zero_tensor_parallel/components.py

import sys
from pathlib import Path

# Make the repository root importable so that `model_zoo` can be resolved
# when the tests are run from this directory.
repo_path = Path(__file__).absolute().parents[2]
sys.path.append(str(repo_path))

try:
    # Imported for its side effect: registers VisionTransformerFromConfig
    # with the model registry used by the config below.
    import model_zoo.vit.vision_transformer_from_config  # noqa: F401
except ImportError:
    raise ImportError("model_zoo is not found, please check your path")
BATCH_SIZE = 8
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
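
# With these settings each image yields (IMG_SIZE // PATCH_SIZE) ** 2 = 64
# patch tokens, and SUMMA_DIM = 2 describes a 2 x 2 device mesh, i.e. the
# 2D (SUMMA) tensor-parallel layers configured below expect 4 GPUs.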
model_cfg = dict(
    type='VisionTransformerFromConfig',
    tensor_splitting_cfg=dict(
        type='ViTInputSplitter2D',
    ),
    embedding_cfg=dict(
        type='ViTPatchEmbedding2D',
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=DIM,
    ),
    token_fusion_cfg=dict(
        type='ViTTokenFuser2D',
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=DIM,
        drop_rate=0.1,
    ),
    norm_cfg=dict(
        type='LayerNorm2D',
        normalized_shape=DIM,
        eps=1e-6,
    ),
    block_cfg=dict(
        type='ViTBlock',
        attention_cfg=dict(
            type='ViTSelfAttention2D',
            hidden_size=DIM,
            num_attention_heads=NUM_ATTENTION_HEADS,
            attention_dropout_prob=0.,
            hidden_dropout_prob=0.1,
        ),
        droppath_cfg=dict(
            type='VanillaViTDropPath',
        ),
        mlp_cfg=dict(
            type='ViTMLP2D',
            in_features=DIM,
            dropout_prob=0.1,
            mlp_ratio=1,
        ),
        norm_cfg=dict(
            type='LayerNorm2D',
            normalized_shape=DIM,
            eps=1e-6,
        ),
    ),
    head_cfg=dict(
        type='ViTHead2D',
        hidden_size=DIM,
        num_classes=NUM_CLASSES,
    ),
    embed_dim=DIM,
    depth=DEPTH,
    drop_path_rate=0.,
)
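
# Usage sketch (not part of the original file): in the early ColossalAI API,
# a registry-style config dict of this shape is typically materialised via
# the builder. The exact entry point is an assumption and may differ across
# releases, so it is left commented out here:
#
#     from colossalai.builder import build_model
#     model = build_model(model_cfg)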