mirror of https://github.com/hpcaitech/ColossalAI
[example] add vit (#1942)
* [ColoTensor] ColoInitContext initialize parameters in shard mode. * polish * [example] add vit
pull/1944/head
parent c7925c5d08
commit cf68cc92ac
@ -0,0 +1,61 @@
# Vision Transformer with ColoTensor

# Overview

In this example, we will run Vision Transformer with ColoTensor.

We use the **ViTForImageClassification** model from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for the unit test.
You can change the world size or decide whether to use DDP in our code.

We use the **vision_transformer** model from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for the training example.

(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. ZeRO is not supported at present.
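As a rough sanity check on what that layout implies, here is a back-of-the-envelope sketch using the `BATCH_SIZE` and `gradient_accumulation` values from the config file included in this commit; the helper function is illustrative only and not part of the example code:

```python
# Effective batch size under the default 2DP + 2TP layout.
# Assumes BATCH_SIZE = 256 (per GPU) and gradient_accumulation = 8, as in the config below.
def effective_global_batch(batch_size_per_gpu: int, dp_size: int, grad_accum: int) -> int:
    # one optimizer update consumes dp_size micro-batches, grad_accum times
    return batch_size_per_gpu * dp_size * grad_accum

print(effective_global_batch(256, 2, 8))  # 4096 samples per optimizer update
```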
# Requirement

You should install colossalai from the main branch at commit 561e904.

## Unit test

To run the unit test, install pytest and transformers with:
```shell
pip install pytest transformers
```

## Training example

To run the training example with ViT-S, install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support.
You also need to install timm and titans for model/dataloader support with:
```shell
pip install timm titans
```

### Data preparation

You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use a TFRecords dataset instead of the raw ImageNet dataset; this gives better I/O throughput. If you don't have a TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one.

Before you start training, set the environment variable `DATA` so that the script knows where to fetch the data for the DALI dataloader.
```shell
export DATA=/path/to/ILSVRC2012
```
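For reference, the training script in this commit consumes `DATA` roughly as sketched below (a trimmed excerpt; the real script reads the batch size from the config instead of hard-coding 256):

```python
import os

from titans.dataloader.imagenet import build_dali_imagenet

# the TFRecords root is taken from the DATA environment variable
root = os.environ['DATA']
train_dataloader, test_dataloader = build_dali_imagenet(root,
                                                        train_batch_size=256,    # BATCH_SIZE in the config
                                                        test_batch_size=256)
```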
# How to run

## Unit test

In your terminal:
```shell
pytest test_vit.py
```

This will evaluate models with different **world_size** and **use_ddp** settings.
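Concretely, the test in this commit parametrizes over both dimensions (excerpt from `test_vit.py` below; the full file follows later in this diff):

```python
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
    ...
```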
## Training example

Modify the settings in `run.sh` according to your environment.
For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file,
the data parallel size will be automatically calculated as 4.
Thus, the parallel strategy is set to 4DP+2TP.
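In other words, the data parallel size is just the number of launched processes divided by the tensor parallel size; the helper below is hypothetical and for illustration only, not part of the example code:

```python
def data_parallel_size(nproc_per_node: int, tp_world_size: int) -> int:
    # the process grid must factor evenly into DP x TP
    assert nproc_per_node % tp_world_size == 0
    return nproc_per_node // tp_world_size

print(data_parallel_size(8, 2))  # 4  ->  4DP + 2TP
```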
Then, in your terminal:
```shell
sh run.sh
```

This will start ViT-S training with ImageNet.
@ -0,0 +1,32 @@
from colossalai.amp import AMP_TYPE

# hyperparameters
# BATCH_SIZE is per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 384
DEPTH = 12
NUM_HEADS = 6
MLP_RATIO = 4
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token

USE_DDP = True
TP_WORLD_SIZE = 2
TP_TYPE = 'row'
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
gradient_accumulation = 8

LOG_PATH = "./log"
@ -0,0 +1,15 @@
export DATA=/data/scratch/imagenet/tf_records
export OMP_NUM_THREADS=4

# resume
# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \
#     --nproc_per_node 4 train.py \
#     --config configs/vit_1d_tp2.py \
#     --resume_from checkpoint/epoch_10 \
#     --master_port 29598 | tee ./out 2>&1

# train
CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \
    --nproc_per_node 4 train.py \
    --config configs/vit_1d_tp2.py \
    --master_port 29598 | tee ./out 2>&1
@ -0,0 +1,132 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.util import set_seed, tensor_equal, tensor_shard_equal
from vit import get_training_components

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


# For Linear layers this is a 1d_row split, because the Linear weight is transposed during computation.
# For other layers it is a 1d_col split.
# Layernorm is not supported for now.
# patch_embeddings.projection uses nn.Conv2d:
# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
def init_1d_row_for_linear_weight_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


# Similarly, this is a col split for Linear but a row split for other layers.
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if ('weight' in n
                    or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p, p)


def check_grad_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if torch_p.grad.shape == p.grad.shape:
            assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
        else:
            # the sharded grad only covers this rank's chunk along the sharded dimension
            dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0)


def run_vit(init_spec_func, use_ddp):
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()
    if use_ddp:
        model = ColoDDP(model)
        torch_model = DDP(torch_model,
                          device_ids=[gpc.get_global_rank()],
                          process_group=gpc.get_group(ParallelMode.DATA))
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    world_size = torch.distributed.get_world_size()
    init_spec_func(model, world_size)

    check_param_equal(model, torch_model)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))

    optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    for i, image_dict in enumerate(train_dataloader):
        if use_ddp:
            model.zero_grad()
        else:
            optimizer.zero_grad()
        logits = model(image_dict['pixel_values'])
        torch_logits = torch_model(image_dict['pixel_values'])
        assert tensor_equal(torch_logits.logits, logits.logits)
        loss = criterion(logits.logits, image_dict['label'])
        torch_loss = criterion(torch_logits.logits, image_dict['label'])
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model)
        optimizer.step()
        torch_optimizer.step()
        check_param_equal(model, torch_model)
        break


def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
    run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_vit(1, False)
@ -0,0 +1,161 @@
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import _create_vision_transformer
from titans.dataloader.imagenet import build_dali_imagenet
from tqdm import tqdm

import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn._ops import *
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def init_1d_row_for_linear_weight_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


# Similarly, this is a col split for Linear but a row split for other layers.
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n
                                                                       and 'patch_embed.proj.bias' not in n):
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


def init_spec_func(model, tp_type):
    world_size = torch.distributed.get_world_size()
    if tp_type == 'row':
        init_1d_row_for_linear_weight_spec(model, world_size)
    elif tp_type == 'col':
        init_1d_col_for_linear_weight_bias_spec(model, world_size)
    else:
        raise NotImplementedError


def train_imagenet():

    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=True, action='store_true')
    parser.add_argument('--resume_from', default=False)

    args = parser.parse_args()
    colossalai.launch_from_torch(config=args.config)
    use_ddp = gpc.config.USE_DDP

    disable_existing_loggers()

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    logger.info('Build data loader', ranks=[0])
    root = os.environ['DATA']
    train_dataloader, test_dataloader = build_dali_imagenet(root,
                                                            train_batch_size=gpc.config.BATCH_SIZE,
                                                            test_batch_size=gpc.config.BATCH_SIZE)

    logger.info('Build model', ranks=[0])

    model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
                        patch_size=gpc.config.PATCH_SIZE,
                        embed_dim=gpc.config.HIDDEN_SIZE,
                        depth=gpc.config.DEPTH,
                        num_heads=gpc.config.NUM_HEADS,
                        mlp_ratio=gpc.config.MLP_RATIO,
                        num_classes=gpc.config.NUM_CLASSES,
                        drop_rate=0.1,
                        attn_drop_rate=0.1,
                        weight_init='jax')

    with ColoInitContext(device=get_current_device()):
        model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs)
    init_spec_func(model, gpc.config.TP_TYPE)

    world_size = torch.distributed.get_world_size()
    model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size))
    logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0])
    optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    criterion = CrossEntropyLoss()
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    start_epoch = 0
    if args.resume_from:
        load_model = torch.load(args.resume_from + '_model.pth')
        start_epoch = load_model['epoch']
        model.load_state_dict(load_model['model'])
        load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank()))
        optimizer.load_state_dict(load_optim['optim'])

    for epoch in range(start_epoch, gpc.config.NUM_EPOCHS):
        model.train()
        for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
            x, y = x.cuda(), y.cuda()
            output = model(x)
            loss = criterion(output, y)
            loss = loss / gpc.config.gradient_accumulation
            if use_ddp:
                model.backward(loss)
            else:
                loss.backward()
            if (index + 1) % gpc.config.gradient_accumulation == 0:
                optimizer.step()
                if use_ddp:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()

        logger.info(
            f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}",
            ranks=[0])

        model.eval()
        test_loss = 0
        correct = 0
        test_sum = 0
        with torch.no_grad():
            for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False):
                x, y = x.cuda(), y.cuda()
                output = model(x)
                test_loss += F.cross_entropy(output, y, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(y.view_as(pred)).sum().item()
                test_sum += y.size(0)

        test_loss /= test_sum
        logger.info(
            f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})",
            ranks=[0])

        lr_scheduler.step()


if __name__ == '__main__':
    train_imagenet()
@ -0,0 +1,67 @@
import torch
import torch.nn as nn
from utils.dummy_data_generator import DummyDataGenerator

from colossalai.utils.cuda import get_current_device
from transformers import ViTConfig, ViTForImageClassification


class DummyDataLoader(DummyDataGenerator):
    batch_size = 4
    channel = 3
    category = 8
    image_size = 224

    def generate(self):
        image_dict = {}
        image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
                                                DummyDataLoader.channel,
                                                DummyDataLoader.image_size,
                                                DummyDataLoader.image_size,
                                                device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        return image_dict


class ViTCVModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 num_labels=8,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = ViTForImageClassification(
            ViTConfig(hidden_size=hidden_size,
                      num_hidden_layers=num_hidden_layers,
                      num_attention_heads=num_attention_heads,
                      image_size=image_size,
                      patch_size=patch_size,
                      num_channels=num_channels,
                      num_labels=num_labels))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, pixel_values):
        return self.model(pixel_values=pixel_values)


def vit_base_s(checkpoint=True):
    return ViTCVModel(checkpoint=checkpoint)


def vit_base_micro(checkpoint=True):
    return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)


def get_training_components():
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy