2022-11-12 09:49:48 +00:00
|
|
|
import torch
|
2023-01-11 08:27:31 +00:00
|
|
|
import torch.nn as nn
|
|
|
|
from torchvision.models import resnet18
|
2022-11-12 09:49:48 +00:00
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
import colossalai
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.core import global_context as gpc
|
2022-11-12 09:49:48 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
|
|
from colossalai.nn.optimizer import Lamb, Lars
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
class DummyDataloader:
    """Iterable that yields a fixed number of randomly generated (image, label) batches.

    Stands in for a real dataset so the example can run without downloading data.
    Every batch is freshly sampled with ``torch.rand`` / ``torch.randint``.

    Args:
        length: number of batches produced per pass over the loader.
        batch_size: number of samples in each batch.
        num_classes: labels are drawn uniformly from ``[0, num_classes)``; should
            match the classifier's output size. Defaults to 10 (the previous
            hard-coded value).
        image_size: spatial height/width of the generated 3-channel images.
            Defaults to 224 (the previous hard-coded value).
    """

    def __init__(self, length, batch_size, num_classes=10, image_size=224):
        self.length = length
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_size = image_size
        # Initialize here too so next() works even if iter() was never called.
        self.step = 0

    def generate(self):
        """Return one random batch: data of shape (batch_size, 3, image_size, image_size)
        and integer labels of shape (batch_size,)."""
        data = torch.rand(self.batch_size, 3, self.image_size, self.image_size)
        label = torch.randint(low=0, high=self.num_classes, size=(self.batch_size,))
        return data, label

    def __iter__(self):
        # Restart the pass; the loader is its own iterator, so it can be re-iterated.
        self.step = 0
        return self

    def __next__(self):
        if self.step >= self.length:
            raise StopIteration
        self.step += 1
        return self.generate()

    def __len__(self):
        return self.length
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train ResNet-18 on synthetic data with a large-batch optimizer (LARS or LAMB).

    Hyperparameters (BATCH_SIZE, NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY,
    NUM_EPOCHS, WARMUP_EPOCHS) are read from ``gpc.config``, which is populated
    from the config referenced by ``args.config``. The optimizer is selected with
    the required ``--optimizer`` flag.
    """
    # initialize distributed setting
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True
    )
    args = parser.parse_args()

    # launch from torch
    colossalai.launch_from_torch(config=args.config)

    # get logger
    logger = get_dist_logger()
    logger.info("initialized distributed environment", ranks=[0])

    # create synthetic dataloaders
    train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
    # NOTE(review): test_dataloader is passed to colossalai.initialize below but is
    # never iterated in this function, so no evaluation is performed.
    test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)

    # build model
    model = resnet18(num_classes=gpc.config.NUM_CLASSES)

    # create loss function
    criterion = nn.CrossEntropyLoss()

    # create optimizer; argparse `choices` guarantees the key is present
    optim_cls = {"lars": Lars, "lamb": Lamb}[args.optimizer]
    optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    # create lr scheduler. Its horizon is expressed in *epochs*
    # (total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS), so it is stepped once
    # per epoch below; stepping it per batch would exhaust the schedule early.
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS
    )

    # initialize
    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
    )

    logger.info("Engine is built", ranks=[0])

    for epoch in range(gpc.config.NUM_EPOCHS):
        # training
        engine.train()
        data_iter = iter(train_dataloader)

        # Only global rank 0 wraps the step range in a tqdm bar, so the bar is
        # printed once rather than by every process.
        steps = range(len(train_dataloader))
        if gpc.get_global_rank() == 0:
            description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS)
            progress = tqdm(steps, desc=description)
        else:
            progress = steps

        for _ in progress:
            engine.zero_grad()
            engine.execute_schedule(data_iter, return_output_label=False)
            engine.step()

        # Advance the epoch-based LR schedule exactly once per epoch.
        lr_scheduler.step()
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when this module is imported.
    main()
|