ColossalAI/tests/test_zero_data_parallel/test_zero.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os.path as osp

import pytest
import torch
from torch.utils.data import DataLoader

import colossalai
from colossalai.builder import build_dataset, build_loss, build_data_sampler, build_model
from colossalai.core import global_context
from colossalai.engine.gradient_handler import DataParallelGradientHandler
from colossalai.nn.optimizer import ZeroRedundancyOptimizer_Level_1, ZeroRedundancyOptimizer_Level_2, \
    ZeroRedundancyOptimizer_Level_3
from colossalai.utils import print_rank_0

DIR_PATH = osp.dirname(osp.abspath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'config.py')
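
# A sketch of what config.py is assumed to provide, inferred from the config
# accesses below (type names and values are illustrative, not the repo's
# actual config):
#
#   model = dict(type='VanillaResNet', ...)          # consumed by build_model
#   level = 2                                        # ZeRO level: 1, 2 or 3
#   loss = dict(type='CrossEntropyLoss')             # consumed by build_loss
#   train_data = dict(
#       dataset=dict(type='CIFAR10Dataset', ...),    # consumed by build_dataset
#       dataloader=dict(batch_size=64, sampler=dict(...)),  # sampler optional
#   )
#   test_data = dict(dataset=dict(...), dataloader=dict(batch_size=64))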


def run_dist():
    colossalai.init_dist(CONFIG_PATH)

    # build resnet model
    model = build_model(global_context.config.model)
    model.build_from_cfg()
    model = model.cuda()

    level = global_context.config.level
    if level > 1:
        model = model.half()

    # record baseline CUDA memory usage before training starts
    _ = torch.rand(1).cuda()
    torch.cuda.synchronize()
    max_alloc = torch.cuda.max_memory_allocated()
    max_reserved = torch.cuda.max_memory_reserved()
    print(f'before run: max_allocated = {max_alloc}, max_reserved = {max_reserved}')
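
    # Note: the rand(1).cuda() above forces CUDA context creation, so the
    # baseline printed here includes context overhead plus the model weights,
    # making later peak-memory readings easier to interpret.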

    # build dataloaders
    train_dataset = build_dataset(global_context.config.train_data.dataset)
    sampler_cfg = global_context.config.train_data.dataloader.pop('sampler', None)
    if sampler_cfg is None:
        train_dataloader = DataLoader(dataset=train_dataset,
                                      **global_context.config.train_data.dataloader)
    else:
        sampler = build_data_sampler(sampler_cfg, train_dataset)
        train_dataloader = DataLoader(dataset=train_dataset,
                                      sampler=sampler,
                                      **global_context.config.train_data.dataloader)

    test_dataset = build_dataset(global_context.config.test_data.dataset)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 **global_context.config.test_data.dataloader)
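
    # Note: the test dataloader is built without a sampler, so (presumably by
    # design) every rank iterates the full test split rather than a shard.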

    # build optimizer and loss
    # optimizer = build_optimizer(global_context.config.optimizer, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    if level == 1:
        zero_optim = ZeroRedundancyOptimizer_Level_1(init_optimizer=optimizer, verbose=False)
    elif level == 2:
        zero_optim = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, cpu_offload=True, verbose=False)
    elif level == 3:
        zero_optim = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer,
                                                     module=model,
                                                     verbose=False,
                                                     offload_optimizer_config=dict(
                                                         device='cpu',
                                                         pin_memory=True,
                                                         buffer_count=5,
                                                         fast_init=False
                                                     ),
                                                     offload_param_config=dict(
                                                         device='cpu',
                                                         pin_memory=True,
                                                         buffer_count=5,
                                                         buffer_size=1e8,
                                                         max_in_cpu=1e9
                                                     ))
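
    # The three wrappers correspond to the three ZeRO stages: level 1 shards
    # optimizer states across data-parallel ranks, level 2 additionally shards
    # gradients (here with cpu_offload=True), and level 3 also shards the
    # parameters themselves, which is why it wraps the module and takes
    # separate offload configs for optimizer states and parameters.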

    loss_fn = build_loss(global_context.config.loss)
    gradient_handler = DataParallelGradientHandler(model, zero_optim)
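
    # Note: DataParallelGradientHandler all-reduces gradients across the
    # data-parallel group; it is only exercised on the ZeRO-1 path below,
    # because the level-2/3 optimizers fold reduction into their own backward().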

    # train
    for epoch in range(100):
        model.train()

        avg_train_loss = 0
        train_iter = 0
        for idx, (data, label) in enumerate(train_dataloader):
            # model = model.half()
            data = data[0].cuda()
            label = label[0].cuda()
            if level > 1:
                data = data.half()

            output = model(data)
            loss = loss_fn(output[0], label)

            if level > 1:
                # ZeRO-2/3: the optimizer owns the backward pass and the
                # gradient partitioning epilogue
                zero_optim.backward(loss)
                zero_optim.overlapping_partition_gradients_reduce_epilogue()
            else:
                # ZeRO-1: plain backward, then all-reduce via the handler
                loss.backward()
                gradient_handler.handle_gradient()

            zero_optim.step()
            zero_optim.zero_grad()

            avg_train_loss += loss.detach().cpu().numpy()
            train_iter += 1

        print_rank_0(f'epoch: {epoch}, train loss: {avg_train_loss / train_iter}')

        # evaluate every other epoch
        if epoch % 2 == 0:
            model.eval()
            avg_eval_loss = 0
            correct = 0
            total = 0
            eval_iters = 0

            for idx, (data, label) in enumerate(test_dataloader):
                with torch.no_grad():
                    data = data[0].cuda()
                    label = label[0].cuda()
                    if level > 1:
                        data = data.half()
                    output = model(data)
                    loss = loss_fn(output[0], label)
                    avg_eval_loss += loss.detach().cpu().numpy()
                    preds = torch.argmax(output[0], dim=1)
                    total += data.size(0)
                    correct += (preds == label).sum().item()
                    eval_iters += 1

            print_rank_0(f'epoch: {epoch}, eval loss: {avg_eval_loss / eval_iters}, acc: {correct / total}')
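
            # Note: metrics are computed locally on each rank and printed from
            # rank 0 only; there is no cross-rank reduction of loss or accuracy.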


@pytest.mark.skip("This test should be invoked manually using the script provided")
@pytest.mark.dist
def test_zero():
    run_dist()


if __name__ == '__main__':
    test_zero()
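
# NOTE (assumption): per the skip marker above, this test is meant to be
# launched manually through the repository's distributed launch script;
# running it under plain pytest would lack the distributed environment
# that colossalai.init_dist expects.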