mirror of https://github.com/hpcaitech/ColossalAI
polish engine unit test
parent
354c0f9047
commit
d271f2596b
|
@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
|
|||
|
||||
|
||||
def run_train():
|
||||
assert non_distributed_component_funcs.get_callable('bert')
|
||||
for get_components_func in non_distributed_component_funcs:
|
||||
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
|
||||
|
||||
|
@ -27,12 +26,15 @@ def run_train():
|
|||
|
||||
try:
|
||||
engine.train()
|
||||
for img, label in train_dataloader:
|
||||
for data, label in train_dataloader:
|
||||
engine.zero_grad()
|
||||
img = img.cuda()
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
output = engine(img)
|
||||
loss = engine.criterion(output, label)
|
||||
if criterion:
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
else:
|
||||
loss = engine(data, label)
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
break
|
||||
|
@ -72,9 +74,9 @@ def run_engine(rank, world_size, port):
|
|||
# init dist env
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_no_amp()
|
||||
# run_with_torch_amp()
|
||||
# run_with_apex_amp()
|
||||
# run_with_naive_amp()
|
||||
run_with_torch_amp()
|
||||
run_with_apex_amp()
|
||||
run_with_naive_amp()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -76,6 +76,7 @@ def run_dist(rank, world_size, port):
|
|||
check_grads(model, zero_model, loose=True)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Under development")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2, 4])
|
||||
def test_shard_model_v2(world_size):
|
||||
|
|
Loading…
Reference in New Issue