import argparse
import os
from pathlib import Path

import torch
from titans.utils import barrier_context
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm

import colossalai
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader

DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()


def synthesize_data():
    # random images and labels shaped like CIFAR10 samples
    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
    return img, label


def main():
    colossalai.launch_from_torch(config='./config.py')
    logger = get_dist_logger()

    # trace the model with meta data
    model = resnet50(num_classes=10).cuda()
    input_sample = {
        'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')
    }
    model = autoparallelize(model, input_sample)

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    # lr_scheduler
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

    for epoch in range(gpc.config.NUM_EPOCHS):
        model.train()

        # if we use synthetic data
        # we assume it only has 30 steps per epoch
        num_steps = range(30)
        progress = tqdm(num_steps)

        for _ in progress:
            # generate fake data
            img, label = synthesize_data()
            img = img.cuda()
            label = label.cuda()
            optimizer.zero_grad()
            output = model(img)
            train_loss = criterion(output, label)
            train_loss.backward()
            optimizer.step()
        lr_scheduler.step()

        # run evaluation
        model.eval()
        correct = 0
        total = 0

        # if we use synthetic data
        # we assume it only has 10 steps for evaluation
        num_steps = range(10)
        progress = tqdm(num_steps)

        for _ in progress:
            # generate fake data
            img, label = synthesize_data()
            img = img.cuda()
            label = label.cuda()

            with torch.no_grad():
                output = model(img)
                test_loss = criterion(output, label)
            pred = torch.argmax(output, dim=-1)
            correct += torch.sum(pred == label)
            total += img.size(0)

        logger.info(
            f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
            ranks=[0])


if __name__ == '__main__':
    main()
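# The call to `colossalai.launch_from_torch(config='./config.py')` expects a config
# file next to this script that defines the attributes read through `gpc.config`,
# namely BATCH_SIZE and NUM_EPOCHS. The original config is not shown here; a minimal
# sketch with illustrative (assumed) values could be:
#
#     BATCH_SIZE = 128
#     NUM_EPOCHS = 10
#
# Since `launch_from_torch` reads the rank/world-size environment variables set by
# the PyTorch launcher, the script can be started on multiple GPUs with, e.g.:
#
#     torchrun --nproc_per_node=<num_gpus> <this_script>.py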