polish code

pull/394/head
jiaruifang 2022-03-09 11:35:11 +08:00 committed by Frank Lee
parent 4d94cd513e
commit 354c0f9047
1 changed file with 5 additions and 2 deletions

View File

@ -44,7 +44,9 @@ def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
test_models = ['bert']
# repeated_computed_layers resnet18
shard_strategy = TensorShardStrategy() shard_strategy = TensorShardStrategy()
for model_name in test_models: for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
@ -58,11 +60,12 @@ def run_dist(rank, world_size, port):
if i > 2: if i > 2:
break break
if model_name == 'bert': if criterion is None:
data, label = data.cuda(), label.cuda() data, label = data.cuda(), label.cuda()
run_fwd_bwd_no_criterion(model, data, label, False) run_fwd_bwd_no_criterion(model, data, label, False)
run_fwd_bwd_no_criterion(zero_model, data, label, False) run_fwd_bwd_no_criterion(zero_model, data, label, False)
else: else:
# FIXME() data can be integer!
data, label = data.half().cuda(), label.cuda() data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False) run_fwd_bwd(model, data, label, criterion, False)
run_fwd_bwd(zero_model, data, label, criterion, False) run_fwd_bwd(zero_model, data, label, criterion, False)