ColossalAI/examples/simclr_cifar10_data_parallel/train_simclr.py

import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_dataloader, MultiTimer
from torchvision.datasets import CIFAR10

from augmentation import SimCLRTransform
from models.simclr import SimCLR
from myhooks import TotalBatchsizeHook
from NT_Xentloss import NT_Xentloss

def build_dataset_train():
    augment = SimCLRTransform()
    train_dataset = CIFAR10(root=gpc.config.dataset.root,
                            transform=augment,
                            train=True,
                            download=True)

    return get_dataloader(
        dataset=train_dataset,
        shuffle=True,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )
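
# For reference only: the imported SimCLRTransform is expected to return two
# independently augmented views of each image, the standard SimCLR recipe
# (random crop, flip, color jitter, grayscale). Below is a minimal sketch under
# that assumption; the class name and augmentation parameters are illustrative,
# not the actual implementation in augmentation.py.
from torchvision import transforms


class ReferenceSimCLRTransform:
    def __init__(self, size=32):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

    def __call__(self, x):
        # two stochastic views of the same image form a positive pair
        return self.transform(x), self.transform(x)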

def build_dataset_test():
    augment = SimCLRTransform()
    val_dataset = CIFAR10(root=gpc.config.dataset.root,
                          transform=augment,
                          train=False)

    return get_dataloader(
        dataset=val_dataset,
        add_sampler=False,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )
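
# For reference only: a minimal sketch of the NT-Xent (normalized
# temperature-scaled cross-entropy) loss from Chen et al., "SimCLR" (2020),
# which the imported NT_Xentloss is assumed to implement. The class name,
# forward signature, and temperature default are illustrative.
import torch
import torch.nn.functional as F


class ReferenceNTXentLoss(torch.nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        # z_i, z_j: [N, D] projections of the two augmented views of N images
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # [2N, D], unit norm
        sim = z @ z.t() / self.temperature                    # scaled cosine similarity
        sim.fill_diagonal_(float('-inf'))                     # a sample is never its own positive
        # the positive for row k is the other view of the same image: (k + N) % 2N
        n = z_i.size(0)
        targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)]).to(sim.device)
        return F.cross_entropy(sim, targets)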

def main():
    colossalai.launch_from_torch(config='./config.py')
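    # ./config.py is expected to define at least the fields this script reads
    # via gpc.config. A minimal sketch (field names taken from this script;
    # the values are illustrative, not the example's actual settings):
    #   BATCH_SIZE = 512
    #   NUM_EPOCHS = 800
    #   LEARNING_RATE = 0.03
    #   WEIGHT_DECAY = 1e-6
    #   MOMENTUM = 0.9
    #   LOG_NAME = 'cifar-simclr'
    #   dataset = dict(root='./data')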
    # get logger
    logger = get_dist_logger()

    # build model: a SimCLR wrapper (backbone + projection head) around ResNet-18
    model = SimCLR(model='resnet18')

    # build dataloaders
    train_dataloader = build_dataset_train()
    test_dataloader = build_dataset_test()

    # build loss (see the NT-Xent reference sketch above)
    criterion = NT_Xentloss()
    # build optimizer
    optimizer = colossalai.nn.FusedSGD(model.parameters(),
                                       lr=gpc.config.LEARNING_RATE,
                                       weight_decay=gpc.config.WEIGHT_DECAY,
                                       momentum=gpc.config.MOMENTUM)

    # build lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(optimizer,
                                           warmup_steps=10,
                                           total_steps=gpc.config.NUM_EPOCHS)
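    # note: LRSchedulerHook below is created with by_epoch=True, so the
    # scheduler is stepped once per epoch and warmup_steps/total_steps here
    # are counted in epochs, not iterations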
    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader, test_dataloader
    )
    logger.info("initialized colossalai components", ranks=[0])

    # build a timer to measure time
    timer = MultiTimer()

    # build trainer
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    # build hooks
    hook_list = [
        hooks.LossHook(),
        hooks.LogMetricByEpochHook(logger),
        hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
        TotalBatchsizeHook(),
        # comment out the two hooks below if you do not need checkpointing
        # or tensorboard logging
        hooks.SaveCheckpointHook(interval=50, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}'),
        hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}', ranks=[0]),
    ]
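    # TotalBatchsizeHook comes from the local myhooks.py and is assumed to be
    # a small custom hook that logs the effective global batch size, roughly:
    #
    #   class TotalBatchsizeHook(hooks.BaseHook):
    #       def __init__(self, priority=10):
    #           super().__init__(priority)
    #
    #       def before_train(self, trainer):
    #           total = gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA)
    #           get_dist_logger().info(f'Total batch size = {total}', ranks=[0])
    #
    # (names and priority are illustrative; ParallelMode would come from
    # colossalai.context)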
    # start training
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
        test_interval=1,
    )

if __name__ == '__main__':
    main()
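
# Usage sketch (assumes the standard torch launcher; the flags shown are
# illustrative for a single node with one GPU):
#
#   torchrun --nproc_per_node=1 train_simclr.py
#
# launch_from_torch picks up rank/world size from the environment variables
# set by torchrun and reads the remaining settings from ./config.py.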