mirror of https://github.com/hpcaitech/ColossalAI
[example] update vit ci script (#2469)
* [example] update vit ci script * [example] update requirements * [example] update requirementspull/2476/head
parent
867c8c2d3a
commit
8e85d2440a
|
@ -0,0 +1,32 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
# hyperparameters
|
||||
# BATCH_SIZE is as per GPU
|
||||
# global batch size = BATCH_SIZE x data parallel size
|
||||
BATCH_SIZE = 8
|
||||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
NUM_EPOCHS = 3
|
||||
WARMUP_EPOCHS = 1
|
||||
|
||||
# model config
|
||||
IMG_SIZE = 224
|
||||
PATCH_SIZE = 16
|
||||
HIDDEN_SIZE = 32
|
||||
DEPTH = 2
|
||||
NUM_HEADS = 4
|
||||
MLP_RATIO = 4
|
||||
NUM_CLASSES = 10
|
||||
CHECKPOINT = False
|
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
|
||||
|
||||
USE_DDP = True
|
||||
TP_WORLD_SIZE = 2
|
||||
TP_TYPE = 'row'
|
||||
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
clip_grad_norm = 1.0
|
||||
gradient_accumulation = 2
|
||||
|
||||
LOG_PATH = "./log_ci"
|
|
@ -1,2 +1,8 @@
|
|||
colossalai >= 0.1.12
|
||||
torch >= 1.8.1
|
||||
numpy>=1.24.1
|
||||
timm>=0.6.12
|
||||
titans>=0.0.7
|
||||
tqdm>=4.61.2
|
||||
transformers>=4.25.1
|
||||
nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
export OMP_NUM_THREADS=4
|
||||
|
||||
pip install -r requirements.txt
|
||||
|
||||
# train
|
||||
colossalai run \
|
||||
--nproc_per_node 4 train.py \
|
||||
--config configs/vit_1d_tp2_ci.py \
|
||||
--dummy_data
|
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|||
from timm.models.vision_transformer import _create_vision_transformer
|
||||
from titans.dataloader.imagenet import build_dali_imagenet
|
||||
from tqdm import tqdm
|
||||
from vit import DummyDataLoader
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
|
@ -56,8 +57,8 @@ def init_spec_func(model, tp_type):
|
|||
def train_imagenet():
|
||||
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument('--from_torch', default=True, action='store_true')
|
||||
parser.add_argument('--resume_from', default=False)
|
||||
parser.add_argument('--resume_from', default=False, action='store_true')
|
||||
parser.add_argument('--dummy_data', default=False, action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
@ -74,10 +75,22 @@ def train_imagenet():
|
|||
logger.log_to_file(log_path)
|
||||
|
||||
logger.info('Build data loader', ranks=[0])
|
||||
root = os.environ['DATA']
|
||||
train_dataloader, test_dataloader = build_dali_imagenet(root,
|
||||
train_batch_size=gpc.config.BATCH_SIZE,
|
||||
test_batch_size=gpc.config.BATCH_SIZE)
|
||||
if not args.dummy_data:
|
||||
root = os.environ['DATA']
|
||||
train_dataloader, test_dataloader = build_dali_imagenet(root,
|
||||
train_batch_size=gpc.config.BATCH_SIZE,
|
||||
test_batch_size=gpc.config.BATCH_SIZE)
|
||||
else:
|
||||
train_dataloader = DummyDataLoader(length=10,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
category=gpc.config.NUM_CLASSES,
|
||||
image_size=gpc.config.IMG_SIZE,
|
||||
return_dict=False)
|
||||
test_dataloader = DummyDataLoader(length=5,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
category=gpc.config.NUM_CLASSES,
|
||||
image_size=gpc.config.IMG_SIZE,
|
||||
return_dict=False)
|
||||
|
||||
logger.info('Build model', ranks=[0])
|
||||
|
||||
|
|
|
@ -32,21 +32,24 @@ class DummyDataGenerator(ABC):
|
|||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator):
|
||||
batch_size = 4
|
||||
channel = 3
|
||||
category = 8
|
||||
image_size = 224
|
||||
|
||||
def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
|
||||
super().__init__(length)
|
||||
self.batch_size = batch_size
|
||||
self.channel = channel
|
||||
self.category = category
|
||||
self.image_size = image_size
|
||||
self.return_dict = return_dict
|
||||
|
||||
def generate(self):
|
||||
image_dict = {}
|
||||
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
|
||||
DummyDataLoader.channel,
|
||||
DummyDataLoader.image_size,
|
||||
DummyDataLoader.image_size,
|
||||
device=get_current_device()) * 2 - 1
|
||||
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
|
||||
image_dict['pixel_values'] = torch.rand(
|
||||
self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
|
||||
image_dict['label'] = torch.randint(self.category, (self.batch_size,),
|
||||
dtype=torch.int64,
|
||||
device=get_current_device())
|
||||
if not self.return_dict:
|
||||
return image_dict['pixel_values'], image_dict['label']
|
||||
return image_dict
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue