diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index 8bbf1678e..d12461353 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -3,9 +3,8 @@ from .memstats_collector import MemStatsCollector # isort:skip from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip from .static_memstats_collector import StaticMemStatsCollector # isort:skip -from .module_tracer_wrapper import MemtracerWrapper # isort:skip __all__ = [ 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', - 'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper' + 'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER' ] diff --git a/colossalai/gemini/memory_tracer/module_tracer_wrapper.py b/colossalai/gemini/memory_tracer/module_tracer_wrapper.py deleted file mode 100644 index ab139516c..000000000 --- a/colossalai/gemini/memory_tracer/module_tracer_wrapper.py +++ /dev/null @@ -1,39 +0,0 @@ -from colossalai.gemini.ophooks import register_ophooks_recursively -from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook - -__all__ = ['MemtracerWrapper'] - - -class _Wrapper(): - - def __init__(self, model, ophook_list): - self._ophook_list = ophook_list - self._model = model - - def __call__(self, *args, **kwargs): - return self._model(*args, **kwargs) - - def forward(self, *args, **kwargs): - return self._model.forward(*args, **kwargs) - - def backward(self, loss): - loss.backward() - for ophook in self._ophook_list: - ophook.post_iter() - - def save_results(self, filename): - for ophook in self._ophook_list: - ophook.save_results(filename) - - def show_mem_stats(self): - self._ophook_list[0].show_mem_stats() - - def named_buffers(self): - return self._model.named_buffers() - - -def MemtracerWrapper(model): - ophook_list = [MemTracerOpHook()] - register_ophooks_recursively(model, ophook_list) - engine = _Wrapper(model, ophook_list) - return engine diff --git a/tests/test_gemini/test_mem_tracer.py b/tests/test_gemini/test_mem_tracer.py deleted file mode 100644 index c777308c1..000000000 --- a/tests/test_gemini/test_mem_tracer.py +++ /dev/null @@ -1,51 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.multiprocessing as mp - -import colossalai -from colossalai.gemini.memory_tracer import MemtracerWrapper -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs - - -def run_tracer(rank, world_size, port, use_grad_check=True): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'bert'] - # test_models = ['bert'] - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - # init model on cpu - # TODO() memtrace hook can not handle buff registered on a non-leaf module (for example the BertEmbedding). - # a simple method is that always puts buff on cuda and viewed them as non-model data. - model = MemtracerWrapper(model_builder(checkpoint=use_grad_check)) - - for n, buff in model.named_buffers(): - buff.data = buff.data.cuda() - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - data = data.cuda() - label = label.cuda() - - run_fwd_bwd(model, data, label, criterion) - - model._ophook_list[0].print_non_model_data() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1]) -@pytest.mark.parametrize("use_grad_check", [True, False]) -@rerun_if_address_is_in_use() -def test_tracer(world_size, use_grad_check): - run_func = partial(run_tracer, world_size=world_size, port=free_port(), use_grad_check=use_grad_check) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_tracer(1, True)