[test] bert test in non-distributed way (#2074)

pull/2075/head^2
Jiarui Fang 2022-12-05 13:32:16 +08:00 committed by GitHub
parent 223332ff7e
commit 616ed91ecd
2 changed files with 8 additions and 6 deletions


@@ -68,16 +68,17 @@ def get_training_components():
         return model

+    is_distrbuted = torch.distributed.is_initialized()
     trainloader = get_bert_data_loader(n_class=vocab_size,
                                        batch_size=2,
                                        total_samples=10000,
                                        sequence_length=sequence_length,
-                                       is_distrbuted=True)
+                                       is_distrbuted=is_distrbuted)
     testloader = get_bert_data_loader(n_class=vocab_size,
                                       batch_size=2,
                                       total_samples=10000,
                                       sequence_length=sequence_length,
-                                      is_distrbuted=True)
+                                      is_distrbuted=is_distrbuted)
     criterion = None

     return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
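The change above lets the BERT test components run without an initialized process group: instead of hard-coding is_distrbuted=True, both loaders now follow torch.distributed.is_initialized() at runtime. Below is a minimal sketch of the pattern, assuming the loader falls back to plain shuffling when no process group exists; get_bert_data_loader itself is not part of this diff, so the helper is illustrative only.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def build_loader(dataset, batch_size, is_distrbuted):
    # Attach a DistributedSampler only when a process group exists; otherwise
    # use plain shuffling so the loader works in a single, non-distributed process.
    if is_distrbuted:
        sampler = DistributedSampler(dataset, shuffle=True)
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Mirrors the call sites above: the decision comes from runtime state, not a constant.
dataset = TensorDataset(torch.randint(0, 100, (32, 16)))
loader = build_loader(dataset, batch_size=2,
                      is_distrbuted=torch.distributed.is_initialized())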


@@ -21,14 +21,15 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
         model.backward(loss)


-def run_param_wrapper_testing():
-    test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
+def test_runtime_mem_tracer():
+    test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model']

     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
         with ColoInitContext(device=torch.device('cpu')):
-            model = model_builder(checkpoint=False)
+            model = model_builder(checkpoint=True)

         model_bk = deepcopy(model)
         runtime_mem_tracer = RuntimeMemTracer(model)

@@ -52,4 +53,4 @@ def run_param_wrapper_testing():

 if __name__ == '__main__':
-    run_param_wrapper_testing()
+    test_runtime_mem_tracer()
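Renaming run_param_wrapper_testing to test_runtime_mem_tracer matters because pytest only collects functions whose names start with "test", so the tracer check now runs during normal test collection rather than only through the __main__ block. The model list also gains 'gpt2' and 'bert', and checkpoint=True presumably builds the models with activation checkpointing enabled. A tiny illustration of the collection convention; the names below are made up for the example.

def build_inputs():                 # plain helper: ignored by pytest collection
    return [3, 1, 2]


def test_inputs_can_be_sorted():    # "test" prefix: collected and executed by pytest
    assert sorted(build_inputs()) == [1, 2, 3]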