mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* [ColoTensor] ColoInitContext initialize parameters in shard mode. * polish * [example] add vitpull/1944/head
Jiarui Fang
2 years ago
committed by
GitHub
6 changed files with 468 additions and 0 deletions
@ -0,0 +1,61 @@
|
||||
# Vision Transformer with ColoTensor |
||||
|
||||
# Overview |
||||
|
||||
In this example, we will run Vision Transformer with ColoTensor. |
||||
|
||||
We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test. |
||||
You can change world size or decide whether use DDP in our code. |
||||
|
||||
We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example. |
||||
|
||||
(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present. |
||||
|
||||
# Requirement |
||||
|
||||
You should install colossalai from main branch with commit 561e904. |
||||
|
||||
## Unit test |
||||
To run unit test, you should install pytest, transformers with: |
||||
```shell |
||||
pip install pytest transformers |
||||
``` |
||||
|
||||
## Training example |
||||
To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support. |
||||
You also need to install timm and titans for model/dataloader support with: |
||||
```shell |
||||
pip install timm titans |
||||
``` |
||||
|
||||
### Data preparation |
||||
You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one. |
||||
|
||||
Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader. |
||||
```shell |
||||
export DATA=/path/to/ILSVRC2012 |
||||
``` |
||||
|
||||
|
||||
# How to run |
||||
|
||||
## Unit test |
||||
In your terminal |
||||
```shell |
||||
pytest test_vit.py |
||||
``` |
||||
|
||||
This will evaluate models with different **world_size** and **use_ddp**. |
||||
|
||||
## Training example |
||||
Modify the settings in run.sh according to your environment. |
||||
For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file, |
||||
data parallel size will be automatically calculated as 4. |
||||
Thus, the parallel strategy is set to 4DP+2TP. |
||||
|
||||
Then in your terminal |
||||
```shell |
||||
sh run.sh |
||||
``` |
||||
|
||||
This will start ViT-S training with ImageNet. |
@ -0,0 +1,32 @@
|
||||
from colossalai.amp import AMP_TYPE |
||||
|
||||
# hyperparameters |
||||
# BATCH_SIZE is as per GPU |
||||
# global batch size = BATCH_SIZE x data parallel size |
||||
BATCH_SIZE = 256 |
||||
LEARNING_RATE = 3e-3 |
||||
WEIGHT_DECAY = 0.3 |
||||
NUM_EPOCHS = 300 |
||||
WARMUP_EPOCHS = 32 |
||||
|
||||
# model config |
||||
IMG_SIZE = 224 |
||||
PATCH_SIZE = 16 |
||||
HIDDEN_SIZE = 384 |
||||
DEPTH = 12 |
||||
NUM_HEADS = 6 |
||||
MLP_RATIO = 4 |
||||
NUM_CLASSES = 1000 |
||||
CHECKPOINT = False |
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token |
||||
|
||||
USE_DDP = True |
||||
TP_WORLD_SIZE = 2 |
||||
TP_TYPE = 'row' |
||||
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) |
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE) |
||||
clip_grad_norm = 1.0 |
||||
gradient_accumulation = 8 |
||||
|
||||
LOG_PATH = "./log" |
@ -0,0 +1,15 @@
|
||||
export DATA=/data/scratch/imagenet/tf_records |
||||
export OMP_NUM_THREADS=4 |
||||
|
||||
# resume |
||||
# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ |
||||
# --nproc_per_node 4 train.py \ |
||||
# --config configs/vit_1d_tp2.py \ |
||||
# --resume_from checkpoint/epoch_10 \ |
||||
# --master_port 29598 | tee ./out 2>&1 |
||||
|
||||
# train |
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ |
||||
--nproc_per_node 4 train.py \ |
||||
--config configs/vit_1d_tp2.py \ |
||||
--master_port 29598 | tee ./out 2>&1 |
@ -0,0 +1,132 @@
|
||||
from functools import partial |
||||
|
||||
import pytest |
||||
import torch |
||||
import torch.multiprocessing as mp |
||||
from torch.nn.parallel import DistributedDataParallel as DDP |
||||
from utils.util import set_seed, tensor_equal, tensor_shard_equal |
||||
from vit import get_training_components |
||||
|
||||
import colossalai |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.parallel.data_parallel import ColoDDP |
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec |
||||
from colossalai.testing import rerun_if_address_is_in_use |
||||
from colossalai.utils import free_port |
||||
from colossalai.utils.cuda import get_current_device |
||||
from colossalai.utils.model.colo_init_context import ColoInitContext |
||||
|
||||
|
||||
# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating. |
||||
# But for other layers, it's 1d_col split. |
||||
# Layernorm is not supported for now. |
||||
# patch_embeddings.projection has nn.Conv2d |
||||
# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182 |
||||
def init_1d_row_for_linear_weight_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n: |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
# Similarly, it's col split for Linear but row split for others. |
||||
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if ('weight' in n |
||||
or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n: |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
def check_param_equal(model, torch_model): |
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()): |
||||
assert tensor_shard_equal(torch_p, p) |
||||
|
||||
|
||||
def check_grad_equal(model, torch_model): |
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()): |
||||
if (torch_p.grad.shape == p.grad.shape): |
||||
assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True |
||||
else: |
||||
dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) |
||||
dim = dims_not_eq.item() |
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) |
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) |
||||
assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True |
||||
|
||||
|
||||
def run_vit(init_spec_func, use_ddp): |
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components() |
||||
with ColoInitContext(device=get_current_device()): |
||||
model = model_builder() |
||||
model = model.cuda() |
||||
torch_model = model_builder().cuda() |
||||
if use_ddp: |
||||
model = ColoDDP(model) |
||||
torch_model = DDP(torch_model, |
||||
device_ids=[gpc.get_global_rank()], |
||||
process_group=gpc.get_group(ParallelMode.DATA)) |
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()): |
||||
torch_p.data.copy_(p) |
||||
|
||||
world_size = torch.distributed.get_world_size() |
||||
init_spec_func(model, world_size) |
||||
|
||||
check_param_equal(model, torch_model) |
||||
model.train() |
||||
torch_model.train() |
||||
set_seed(gpc.get_local_rank(ParallelMode.DATA)) |
||||
|
||||
optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) |
||||
torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) |
||||
|
||||
for i, image_dict in enumerate(train_dataloader): |
||||
if use_ddp: |
||||
model.zero_grad() |
||||
else: |
||||
optimizer.zero_grad() |
||||
logits = model(image_dict['pixel_values']) |
||||
torch_logits = torch_model(image_dict['pixel_values']) |
||||
assert tensor_equal(torch_logits.logits, logits.logits) |
||||
loss = criterion(logits.logits, image_dict['label']) |
||||
torch_loss = criterion(torch_logits.logits, image_dict['label']) |
||||
if use_ddp: |
||||
model.backward(loss) |
||||
else: |
||||
loss.backward() |
||||
torch_loss.backward() |
||||
check_grad_equal(model, torch_model) |
||||
optimizer.step() |
||||
torch_optimizer.step() |
||||
check_param_equal(model, torch_model) |
||||
break |
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp): |
||||
if use_ddp and world_size == 1: |
||||
return |
||||
tp_world_size = world_size // 2 if use_ddp else world_size |
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) |
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
||||
run_vit(init_1d_row_for_linear_weight_spec, use_ddp) |
||||
run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp) |
||||
|
||||
|
||||
@pytest.mark.dist |
||||
@pytest.mark.parametrize('world_size', [1, 4]) |
||||
@pytest.mark.parametrize('use_ddp', [False, True]) |
||||
@rerun_if_address_is_in_use() |
||||
def test_vit(world_size, use_ddp): |
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) |
||||
mp.spawn(run_func, nprocs=world_size) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
test_vit(1, False) |
@ -0,0 +1,161 @@
|
||||
import os |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
from timm.models.vision_transformer import _create_vision_transformer |
||||
from titans.dataloader.imagenet import build_dali_imagenet |
||||
from tqdm import tqdm |
||||
|
||||
import colossalai |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger |
||||
from colossalai.nn import CrossEntropyLoss |
||||
from colossalai.nn._ops import * |
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR |
||||
from colossalai.nn.optimizer import HybridAdam |
||||
from colossalai.nn.parallel.data_parallel import ColoDDP |
||||
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec |
||||
from colossalai.utils import get_current_device |
||||
from colossalai.utils.model.colo_init_context import ColoInitContext |
||||
|
||||
|
||||
def init_1d_row_for_linear_weight_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n: |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
# Similarly, it's col split for Linear but row split for others. |
||||
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n |
||||
and 'patch_embed.proj.bias' not in n): |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
def init_spec_func(model, tp_type): |
||||
world_size = torch.distributed.get_world_size() |
||||
if tp_type == 'row': |
||||
init_1d_row_for_linear_weight_spec(model, world_size) |
||||
elif tp_type == 'col': |
||||
init_1d_col_for_linear_weight_bias_spec(model, world_size) |
||||
else: |
||||
raise NotImplemented |
||||
|
||||
|
||||
def train_imagenet(): |
||||
|
||||
parser = colossalai.get_default_parser() |
||||
parser.add_argument('--from_torch', default=True, action='store_true') |
||||
parser.add_argument('--resume_from', default=False) |
||||
|
||||
args = parser.parse_args() |
||||
colossalai.launch_from_torch(config=args.config) |
||||
use_ddp = gpc.config.USE_DDP |
||||
|
||||
disable_existing_loggers() |
||||
|
||||
logger = get_dist_logger() |
||||
if hasattr(gpc.config, 'LOG_PATH'): |
||||
if gpc.get_global_rank() == 0: |
||||
log_path = gpc.config.LOG_PATH |
||||
if not os.path.exists(log_path): |
||||
os.mkdir(log_path) |
||||
logger.log_to_file(log_path) |
||||
|
||||
logger.info('Build data loader', ranks=[0]) |
||||
root = os.environ['DATA'] |
||||
train_dataloader, test_dataloader = build_dali_imagenet(root, |
||||
train_batch_size=gpc.config.BATCH_SIZE, |
||||
test_batch_size=gpc.config.BATCH_SIZE) |
||||
|
||||
logger.info('Build model', ranks=[0]) |
||||
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE, |
||||
patch_size=gpc.config.PATCH_SIZE, |
||||
embed_dim=gpc.config.HIDDEN_SIZE, |
||||
depth=gpc.config.DEPTH, |
||||
num_heads=gpc.config.NUM_HEADS, |
||||
mlp_ratio=gpc.config.MLP_RATIO, |
||||
num_classes=gpc.config.NUM_CLASSES, |
||||
drop_rate=0.1, |
||||
attn_drop_rate=0.1, |
||||
weight_init='jax') |
||||
|
||||
with ColoInitContext(device=get_current_device()): |
||||
model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs) |
||||
init_spec_func(model, gpc.config.TP_TYPE) |
||||
|
||||
world_size = torch.distributed.get_world_size() |
||||
model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size)) |
||||
logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0]) |
||||
optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) |
||||
|
||||
criterion = CrossEntropyLoss() |
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, |
||||
total_steps=gpc.config.NUM_EPOCHS, |
||||
warmup_steps=gpc.config.WARMUP_EPOCHS) |
||||
|
||||
start_epoch = 0 |
||||
if args.resume_from: |
||||
load_model = torch.load(args.resume_from + '_model.pth') |
||||
start_epoch = load_model['epoch'] |
||||
model.load_state_dict(load_model['model']) |
||||
load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank())) |
||||
optimizer.load_state_dict(load_optim['optim']) |
||||
|
||||
for epoch in range(start_epoch, gpc.config.NUM_EPOCHS): |
||||
model.train() |
||||
for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False): |
||||
x, y = x.cuda(), y.cuda() |
||||
output = model(x) |
||||
loss = criterion(output, y) |
||||
loss = loss / gpc.config.gradient_accumulation |
||||
if use_ddp: |
||||
model.backward(loss) |
||||
else: |
||||
loss.backward() |
||||
if (index + 1) % gpc.config.gradient_accumulation == 0: |
||||
optimizer.step() |
||||
if use_ddp: |
||||
model.zero_grad() |
||||
else: |
||||
optimizer.zero_grad() |
||||
|
||||
logger.info( |
||||
f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}", |
||||
ranks=[0]) |
||||
|
||||
model.eval() |
||||
test_loss = 0 |
||||
correct = 0 |
||||
test_sum = 0 |
||||
with torch.no_grad(): |
||||
for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False): |
||||
x, y = x.cuda(), y.cuda() |
||||
output = model(x) |
||||
test_loss += F.cross_entropy(output, y, reduction='sum').item() |
||||
pred = output.argmax(dim=1, keepdim=True) |
||||
correct += pred.eq(y.view_as(pred)).sum().item() |
||||
test_sum += y.size(0) |
||||
|
||||
test_loss /= test_sum |
||||
logger.info( |
||||
f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})", |
||||
ranks=[0]) |
||||
|
||||
lr_scheduler.step() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
train_imagenet() |
@ -0,0 +1,67 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
from utils.dummy_data_generator import DummyDataGenerator |
||||
|
||||
from colossalai.utils.cuda import get_current_device |
||||
from transformers import ViTConfig, ViTForImageClassification |
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator): |
||||
batch_size = 4 |
||||
channel = 3 |
||||
category = 8 |
||||
image_size = 224 |
||||
|
||||
def generate(self): |
||||
image_dict = {} |
||||
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size, |
||||
DummyDataLoader.channel, |
||||
DummyDataLoader.image_size, |
||||
DummyDataLoader.image_size, |
||||
device=get_current_device()) * 2 - 1 |
||||
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,), |
||||
dtype=torch.int64, |
||||
device=get_current_device()) |
||||
return image_dict |
||||
|
||||
|
||||
class ViTCVModel(nn.Module): |
||||
|
||||
def __init__(self, |
||||
hidden_size=768, |
||||
num_hidden_layers=12, |
||||
num_attention_heads=12, |
||||
image_size=224, |
||||
patch_size=16, |
||||
num_channels=3, |
||||
num_labels=8, |
||||
checkpoint=False): |
||||
super().__init__() |
||||
self.checkpoint = checkpoint |
||||
self.model = ViTForImageClassification( |
||||
ViTConfig(hidden_size=hidden_size, |
||||
num_hidden_layers=num_hidden_layers, |
||||
num_attention_heads=num_attention_heads, |
||||
image_size=image_size, |
||||
patch_size=patch_size, |
||||
num_channels=num_channels, |
||||
num_labels=num_labels)) |
||||
if checkpoint: |
||||
self.model.gradient_checkpointing_enable() |
||||
|
||||
def forward(self, pixel_values): |
||||
return self.model(pixel_values=pixel_values) |
||||
|
||||
|
||||
def vit_base_s(checkpoint=True): |
||||
return ViTCVModel(checkpoint=checkpoint) |
||||
|
||||
|
||||
def vit_base_micro(checkpoint=True): |
||||
return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint) |
||||
|
||||
|
||||
def get_training_components(): |
||||
trainloader = DummyDataLoader() |
||||
testloader = DummyDataLoader() |
||||
return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy |
Loading…
Reference in new issue