Browse Source

polish engine unitest

pull/394/head
jiaruifang 3 years ago committed by Frank Lee
parent
commit
d271f2596b
  1. 18
      tests/test_engine/test_engine.py
  2. 1
      tests/test_zero_data_parallel/test_shard_model_v2.py

18
tests/test_engine/test_engine.py

@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
def run_train():
assert non_distributed_component_funcs.get_callable('bert')
for get_components_func in non_distributed_component_funcs:
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
@ -27,12 +26,15 @@ def run_train():
try:
engine.train()
for img, label in train_dataloader:
for data, label in train_dataloader:
engine.zero_grad()
img = img.cuda()
data = data.cuda()
label = label.cuda()
output = engine(img)
loss = engine.criterion(output, label)
if criterion:
output = engine(data)
loss = engine.criterion(output, label)
else:
loss = engine(data, label)
engine.backward(loss)
engine.step()
break
@ -72,9 +74,9 @@ def run_engine(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_no_amp()
# run_with_torch_amp()
# run_with_apex_amp()
# run_with_naive_amp()
run_with_torch_amp()
run_with_apex_amp()
run_with_naive_amp()
@pytest.mark.dist

1
tests/test_zero_data_parallel/test_shard_model_v2.py

@ -76,6 +76,7 @@ def run_dist(rank, world_size, port):
check_grads(model, zero_model, loose=True)
@pytest.mark.skip(reason="Under development")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_shard_model_v2(world_size):

Loading…
Cancel
Save