polish engine unitest

pull/394/head
jiaruifang 3 years ago committed by Frank Lee
parent 354c0f9047
commit d271f2596b

@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
def run_train(): def run_train():
assert non_distributed_component_funcs.get_callable('bert')
for get_components_func in non_distributed_component_funcs: for get_components_func in non_distributed_component_funcs:
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func() model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
@ -27,12 +26,15 @@ def run_train():
try: try:
engine.train() engine.train()
for img, label in train_dataloader: for data, label in train_dataloader:
engine.zero_grad() engine.zero_grad()
img = img.cuda() data = data.cuda()
label = label.cuda() label = label.cuda()
output = engine(img) if criterion:
loss = engine.criterion(output, label) output = engine(data)
loss = engine.criterion(output, label)
else:
loss = engine(data, label)
engine.backward(loss) engine.backward(loss)
engine.step() engine.step()
break break
@ -72,9 +74,9 @@ def run_engine(rank, world_size, port):
# init dist env # init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_no_amp() run_with_no_amp()
# run_with_torch_amp() run_with_torch_amp()
# run_with_apex_amp() run_with_apex_amp()
# run_with_naive_amp() run_with_naive_amp()
@pytest.mark.dist @pytest.mark.dist

@ -76,6 +76,7 @@ def run_dist(rank, world_size, port):
check_grads(model, zero_model, loose=True) check_grads(model, zero_model, loose=True)
@pytest.mark.skip(reason="Under development")
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4]) @pytest.mark.parametrize("world_size", [1, 2, 4])
def test_shard_model_v2(world_size): def test_shard_model_v2(world_size):

Loading…
Cancel
Save