mirror of https://github.com/hpcaitech/ColossalAI
polish engine unit test
parent
354c0f9047
commit
d271f2596b
|
@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
|
||||||
|
|
||||||
|
|
||||||
def run_train():
|
def run_train():
|
||||||
assert non_distributed_component_funcs.get_callable('bert')
|
|
||||||
for get_components_func in non_distributed_component_funcs:
|
for get_components_func in non_distributed_component_funcs:
|
||||||
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
|
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
|
||||||
|
|
||||||
|
@ -27,12 +26,15 @@ def run_train():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
engine.train()
|
engine.train()
|
||||||
for img, label in train_dataloader:
|
for data, label in train_dataloader:
|
||||||
engine.zero_grad()
|
engine.zero_grad()
|
||||||
img = img.cuda()
|
data = data.cuda()
|
||||||
label = label.cuda()
|
label = label.cuda()
|
||||||
output = engine(img)
|
if criterion:
|
||||||
loss = engine.criterion(output, label)
|
output = engine(data)
|
||||||
|
loss = engine.criterion(output, label)
|
||||||
|
else:
|
||||||
|
loss = engine(data, label)
|
||||||
engine.backward(loss)
|
engine.backward(loss)
|
||||||
engine.step()
|
engine.step()
|
||||||
break
|
break
|
||||||
|
@ -72,9 +74,9 @@ def run_engine(rank, world_size, port):
|
||||||
# init dist env
|
# init dist env
|
||||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
run_with_no_amp()
|
run_with_no_amp()
|
||||||
# run_with_torch_amp()
|
run_with_torch_amp()
|
||||||
# run_with_apex_amp()
|
run_with_apex_amp()
|
||||||
# run_with_naive_amp()
|
run_with_naive_amp()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
|
|
@ -76,6 +76,7 @@ def run_dist(rank, world_size, port):
|
||||||
check_grads(model, zero_model, loose=True)
|
check_grads(model, zero_model, loose=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Under development")
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [1, 2, 4])
|
@pytest.mark.parametrize("world_size", [1, 2, 4])
|
||||||
def test_shard_model_v2(world_size):
|
def test_shard_model_v2(world_size):
|
||||||
|
|
Loading…
Reference in New Issue