diff --git a/colossalai/gemini/memory_tracer/module_tracer_wrapper.py b/colossalai/gemini/memory_tracer/module_tracer_wrapper.py
index 9967df627..ab139516c 100644
--- a/colossalai/gemini/memory_tracer/module_tracer_wrapper.py
+++ b/colossalai/gemini/memory_tracer/module_tracer_wrapper.py
@@ -28,6 +28,9 @@ class _Wrapper():
     def show_mem_stats(self):
         self._ophook_list[0].show_mem_stats()
 
+    def named_buffers(self):
+        return self._model.named_buffers()
+
 
 def MemtracerWrapper(model):
     ophook_list = [MemTracerOpHook()]
diff --git a/colossalai/gemini/ophooks/mem_trace_hook.py b/colossalai/gemini/ophooks/mem_trace_hook.py
index 49982b175..697655259 100644
--- a/colossalai/gemini/ophooks/mem_trace_hook.py
+++ b/colossalai/gemini/ophooks/mem_trace_hook.py
@@ -7,6 +7,7 @@ from colossalai.gemini.ophooks import BaseOpHook
 class MemTracerOpHook(BaseOpHook):
     """
     TODO() what if parameters are sharded by multiple submodules.
+    What about buffers registered on a parent (non-leaf) module?
     """
 
     def __init__(self):
diff --git a/tests/test_gemini/test_mem_tracer.py b/tests/test_gemini/test_mem_tracer.py
index c7700d9d7..05da462a4 100644
--- a/tests/test_gemini/test_mem_tracer.py
+++ b/tests/test_gemini/test_mem_tracer.py
@@ -1,8 +1,13 @@
+from functools import partial
+
+import pytest
 import torch
-import torch.nn as nn
+import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.gemini.memory_tracer import MemtracerWrapper
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 
@@ -17,16 +22,20 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
     model.backward(loss)
 
 
-def test_tracer():
-    # reset the manager, in case that there exists memory information left
-    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module']
+def run_tracer(rank, world_size, port, grad_check=True):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
 
         # init model on cpu
-        model = MemtracerWrapper(model_builder())
+        # TODO() the memory tracer hook cannot handle buffers registered on a non-leaf module (e.g. BertEmbedding).
+        # A simple workaround is to always put buffers on CUDA and treat them as non-model data.
+        model = MemtracerWrapper(model_builder(grad_check))
+        for n, buff in model.named_buffers():
+            buff.data = buff.data.cuda()
 
         for i, (data, label) in enumerate(train_dataloader):
             if i > 1:
                 break
@@ -38,5 +47,13 @@ def test_tracer():
         # model._ophook_list[0].print_non_model_data()
 
 
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1])
+@rerun_if_address_is_in_use()
+def test_tracer(world_size):
+    run_func = partial(run_tracer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_tracer()
+    test_tracer(1)
diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py
index 654c82a46..d77a78e8e 100644
--- a/tests/test_zero/test_shard_model_v2.py
+++ b/tests/test_zero/test_shard_model_v2.py
@@ -3,21 +3,21 @@
 
 from functools import partial
 
-import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from common import CONFIG, check_grads_padding, run_fwd_bwd
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
 @parameterize("enable_autocast", [True])
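Note (not part of the patch): a minimal sketch of the usage pattern this diff introduces in test_mem_tracer.py, assuming colossalai.launch() has already been called as in run_tracer; the helper name trace_model_memory is illustrative, not part of the Colossal-AI API.

    import torch.nn as nn

    from colossalai.gemini.memory_tracer import MemtracerWrapper


    def trace_model_memory(model: nn.Module):
        # Wrap the model so MemTracerOpHook records memory usage around each op.
        model = MemtracerWrapper(model)
        # Workaround from this patch: the hook cannot handle buffers registered on a
        # non-leaf module (e.g. BertEmbedding), so move every buffer to CUDA up front
        # and let the tracer treat it as non-model data.
        for _, buff in model.named_buffers():
            buff.data = buff.data.cuda()
        return model

After running forward and backward through the wrapped model (model.backward(loss), as in run_fwd_bwd), model.show_mem_stats() prints the collected memory statistics.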