#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os.path as osp

import pytest
import torch
from torch.utils.data import DataLoader

import colossalai
from colossalai.builder import build_dataset, build_loss, build_data_sampler, build_model
from colossalai.core import global_context
from colossalai.engine.gradient_handler import DataParallelGradientHandler
from colossalai.nn.optimizer import ZeroRedundancyOptimizer_Level_1, ZeroRedundancyOptimizer_Level_2, \
    ZeroRedundancyOptimizer_Level_3
from colossalai.utils import print_rank_0

DIR_PATH = osp.dirname(osp.abspath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'config.py')


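# This manual test builds the model described in config.py, wraps a plain Adam
# optimizer in one of the three ZeRO optimizer levels selected through
# `global_context.config.level`, and runs a short train/eval loop while
# reporting CUDA memory usage, training loss and evaluation accuracy.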
def run_dist():
    colossalai.init_dist(CONFIG_PATH)

    # build resnet model
    model = build_model(global_context.config.model)
    model.build_from_cfg()
    model = model.cuda()

    level = global_context.config.level

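    # ZeRO levels 2 and 3 run the model in fp16 in this test, so the parameters
    # are cast before the optimizer is constructed.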
    if level > 1:
        model = model.half()

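    # The dummy allocation below initialises the CUDA context so that the
    # baseline memory figures printed here are meaningful.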
    # test init cuda memory
    _ = torch.rand(1).cuda()
    torch.cuda.synchronize()
    max_alloc = torch.cuda.max_memory_allocated()
    max_reserved = torch.cuda.max_memory_reserved()
    print(f'before run: max_allocation = {max_alloc}, max_reserved = {max_reserved}')

    # build dataloader
    train_dataset = build_dataset(global_context.config.train_data.dataset)

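    # The sampler config (if any) is popped from the dataloader kwargs because
    # it has to be built into a sampler object first; the remaining kwargs are
    # forwarded to DataLoader unchanged.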
    sampler_cfg = global_context.config.train_data.dataloader.pop('sampler', None)
    if sampler_cfg is None:
        train_dataloader = DataLoader(dataset=train_dataset, **global_context.config.train_data.dataloader)
    else:
        sampler = build_data_sampler(sampler_cfg, train_dataset)
        train_dataloader = DataLoader(dataset=train_dataset, sampler=sampler,
                                      **global_context.config.train_data.dataloader)

    test_dataset = build_dataset(global_context.config.test_data.dataset)
    test_dataloader = DataLoader(dataset=test_dataset, **global_context.config.test_data.dataloader)

    # build optimizer and loss
    # optimizer = build_optimizer(global_context.config.optimizer, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
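    # Wrap the base optimizer according to the configured ZeRO level:
    # level 1 partitions optimizer states, level 2 additionally partitions
    # gradients (with CPU offload enabled here), and level 3 additionally
    # partitions the parameters themselves, offloading both optimizer state
    # and parameters to pinned CPU memory.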
    if level == 1:
        zero_optim = ZeroRedundancyOptimizer_Level_1(init_optimizer=optimizer, verbose=False)
    elif level == 2:
        zero_optim = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, cpu_offload=True, verbose=False)
    elif level == 3:
        zero_optim = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer,
                                                     module=model,
                                                     verbose=False,
                                                     offload_optimizer_config=dict(
                                                         device='cpu',
                                                         pin_memory=True,
                                                         buffer_count=5,
                                                         fast_init=False
                                                     ),
                                                     offload_param_config=dict(
                                                         device='cpu',
                                                         pin_memory=True,
                                                         buffer_count=5,
                                                         buffer_size=1e8,
                                                         max_in_cpu=1e9
                                                     ))

    loss_fn = build_loss(global_context.config.loss)
    gradient_handler = DataParallelGradientHandler(model, zero_optim)

    # train
    for epoch in range(100):
        model.train()

        # train
        avg_train_loss = 0
        train_iter = 0

        for idx, (data, label) in enumerate(train_dataloader):
            # model = model.half()
            data = data[0].cuda()
            label = label[0].cuda()

            if level > 1:
                data = data.half()

            output = model(data)
            loss = loss_fn(output[0], label)

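            # ZeRO levels 2 and 3 manage gradient partitioning themselves, so the
            # backward pass goes through the ZeRO optimizer; level 1 uses a plain
            # backward pass and relies on the data-parallel gradient handler to
            # all-reduce gradients across the data-parallel group.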
            if level > 1:
                zero_optim.backward(loss)
                zero_optim.overlapping_partition_gradients_reduce_epilogue()
            else:
                loss.backward()
                gradient_handler.handle_gradient()

            zero_optim.step()
            zero_optim.zero_grad()

            avg_train_loss += loss.detach().cpu().numpy()
            train_iter += 1

        print_rank_0(f'epoch: {epoch}, train loss: {avg_train_loss / train_iter}')

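        # run evaluation every second epoch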
        if epoch % 2 == 0:
            model.eval()
            avg_eval_loss = 0
            correct = 0
            total = 0
            eval_iters = 0

            for idx, (data, label) in enumerate(test_dataloader):
                with torch.no_grad():
                    data = data[0].cuda()
                    label = label[0].cuda()

                    if level > 1:
                        data = data.half()

                    output = model(data)
                    loss = loss_fn(output[0], label)

                avg_eval_loss += loss.detach().cpu().numpy()
                preds = torch.argmax(output[0], dim=1)
                total += data.size(0)
                correct += sum(preds == label)
                eval_iters += 1

            print_rank_0(f'epoch: {epoch}, eval loss: {avg_eval_loss / eval_iters}, acc: {correct / total}')


@pytest.mark.skip("This test should be invoked manually using the script provided")
@pytest.mark.dist
def test_zero():
    run_dist()


if __name__ == '__main__':
    test_zero()