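"""Train ResNet-50 on synthetic CIFAR10-like data with ColossalAI auto-parallel.

The model is traced with meta tensors, a sharding solution is searched on a
2x2 device mesh (4 GPUs), and a short train/eval loop runs on random data.
"""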
import torch
from torchvision.models import resnet50
from tqdm import tqdm

import colossalai
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR


def synthesize_data():
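    # return one random CIFAR10-like batch: BATCH_SIZE images of 3x32x32 with labels in [0, 10)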
    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
    return img, label


def main():
    colossalai.launch_from_torch(config='./config.py')
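    # ./config.py is expected to define at least BATCH_SIZE and NUM_EPOCHS,
    # which are read below through gpc.config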
    logger = get_dist_logger()

    # trace the model with meta data
    model = resnet50(num_classes=10).cuda()
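    # the sample input lives on the 'meta' device, so tracing allocates no real GPU memory;
    # its batch dimension is the global batch size (per-rank BATCH_SIZE x world size)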
    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
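    # arrange the 4 physical GPUs into a 2x2 logical mesh for the auto-parallel solver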
    device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
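    # search a parallelization solution and shard the model accordingly;
    # return_solution=True also returns the chosen strategy for each traced node,
    # which is printed on rank 0 below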
    model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)

    if gpc.get_global_rank() == 0:
        for node_strategy in solution:
            print(node_strategy)
    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # lr_scheduler
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

    for epoch in range(gpc.config.NUM_EPOCHS):
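        # train on a fixed number of synthetic batches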
        model.train()

        # if we use synthetic data
        # we assume it only has 10 steps per epoch
        num_steps = range(10)
        progress = tqdm(num_steps)

        for _ in progress:
            # generate fake data
            img, label = synthesize_data()

            img = img.cuda()
            label = label.cuda()
            optimizer.zero_grad()
            output = model(img)
            train_loss = criterion(output, label)
            train_loss.backward()
            torch.cuda.synchronize()
            optimizer.step()
        lr_scheduler.step()

        # run evaluation
        model.eval()
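        # accuracy over random labels only sanity-checks the pipeline, not model quality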
        correct = 0
        total = 0

        # if we use synthetic data
        # we assume it only has 10 steps for evaluation
        num_steps = range(10)
        progress = tqdm(num_steps)

        for _ in progress:
            # generate fake data
            img, label = synthesize_data()

            img = img.cuda()
            label = label.cuda()

            with torch.no_grad():
                output = model(img)
                test_loss = criterion(output, label)

            pred = torch.argmax(output, dim=-1)
            correct += torch.sum(pred == label)
            total += img.size(0)
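        # log a per-epoch summary on rank 0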
        logger.info(
            f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
            ranks=[0])


if __name__ == '__main__':
    main()