diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index 63fa2740f..c1faa6f9d 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -68,16 +68,17 @@ def get_training_components(): return model + is_distrbuted = torch.distributed.is_initialized() trainloader = get_bert_data_loader(n_class=vocab_size, batch_size=2, total_samples=10000, sequence_length=sequence_length, - is_distrbuted=True) + is_distrbuted=is_distrbuted) testloader = get_bert_data_loader(n_class=vocab_size, batch_size=2, total_samples=10000, sequence_length=sequence_length, - is_distrbuted=True) + is_distrbuted=is_distrbuted) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py index 0b112f66f..a494b8f59 100644 --- a/tests/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -21,14 +21,15 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc model.backward(loss) -def run_param_wrapper_testing(): - test_models = ['simple_net', 'repeated_computed_layers', 'nested_model'] +def test_runtime_mem_tracer(): + test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model'] + for model_name in test_models: get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, _, criterion = get_components_func() with ColoInitContext(device=torch.device('cpu')): - model = model_builder(checkpoint=False) + model = model_builder(checkpoint=True) model_bk = deepcopy(model) runtime_mem_tracer = RuntimeMemTracer(model) @@ -52,4 +53,4 @@ def run_param_wrapper_testing(): if __name__ == '__main__': - run_param_wrapper_testing() + test_runtime_mem_tracer()