From ea2872073f365dc6461539d2f6c6d07081181ade Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 10 Mar 2022 11:20:04 +0800 Subject: [PATCH] [zero] global model data memory tracer (#360) --- colossalai/utils/commons/singleton_meta.py | 18 ++++++ colossalai/utils/memory_tracer/allocator.py | 60 +++++++++++++++++++ colossalai/zero/init_ctx/init_context.py | 10 +++- .../zero/sharded_param/sharded_tensor.py | 2 +- .../test_init_context.py | 8 ++- 5 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 colossalai/utils/commons/singleton_meta.py create mode 100644 colossalai/utils/memory_tracer/allocator.py diff --git a/colossalai/utils/commons/singleton_meta.py b/colossalai/utils/commons/singleton_meta.py new file mode 100644 index 000000000..f4d3276e2 --- /dev/null +++ b/colossalai/utils/commons/singleton_meta.py @@ -0,0 +1,18 @@ +class SingletonMeta(type): + """ + The Singleton class can be implemented in different ways in Python. Some + possible methods include: base class, decorator, metaclass. We will use the + metaclass because it is best suited for this purpose. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """ + Possible changes to the value of the `__init__` argument do not affect + the returned instance. + """ + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] diff --git a/colossalai/utils/memory_tracer/allocator.py b/colossalai/utils/memory_tracer/allocator.py new file mode 100644 index 000000000..368aae2da --- /dev/null +++ b/colossalai/utils/memory_tracer/allocator.py @@ -0,0 +1,60 @@ +import torch +from colossalai.utils.commons.singleton_meta import SingletonMeta +from colossalai.zero.sharded_param import ShardedTensor + +from typing import Union + + +def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int: + if isinstance(t, ShardedTensor): + target = t.payload + else: + target = t + return target.numel() * target.element_size() + + +class ModelDataTracer(metaclass=SingletonMeta): + """ + A singleton to trace model data usage during runtime. + """ + + def __init__(self) -> None: + self._cpu_usage = 0 + self._cuda_usage = 0 + + def trace_tensor(self, t: torch.Tensor): + mem_use = col_tensor_mem_usage(t) + if t.device.type == 'cpu': + self._cpu_usage += mem_use + elif t.device.type == 'cuda': + self._cuda_usage += mem_use + else: + raise RuntimeError + + def detach_tensor(self, t: torch.Tensor): + mem_use = col_tensor_mem_usage(t) + if t.device.type == 'cpu': + self._cpu_usage -= mem_use + elif t.device.type == 'cuda': + self._cuda_usage -= mem_use + else: + raise RuntimeError + + @property + def cpu_usage(self): + return self._cpu_usage + + @property + def cuda_usage(self): + return self._cuda_usage + + +GLOBAL_MODEL_DATA_TRACER = ModelDataTracer() + + +def col_allocate_payload(device: torch.device) -> torch.Tensor: + pass + + +def col_release_payload(t: torch.Tensor): + pass diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 619168229..f045e144f 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -4,6 +4,7 @@ import torch from colossalai.utils.cuda import get_current_device from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_param import ShardedParamV2 +from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER # Inserts _post_init_method at the end of init method @@ -76,11 +77,16 @@ class InsertPostInitMethodToModuleSubClasses(object): class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): - """ + r""" A context to initialize model. 1. Convert the model to fp16. 2. The paramaters of the module are adapted to type ShardedParameter. 3. Shard the param and grad according to flags. + rm_torch_payload_on_the_fly: + True: remove tensor payload on param.data after module init finished. + False: remove tensor payload on param.data afther the context exist. + This is used when you add some logic to operate tensors in __init__ of module. + See torchvision resnet18. """ def __init__(self, @@ -134,5 +140,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if self.shard_param: self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor]) + GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload) if param.col_attr.grad and self.shard_grad: self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor]) + GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload) diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py index 823222725..093889b4b 100644 --- a/colossalai/zero/sharded_param/sharded_tensor.py +++ b/colossalai/zero/sharded_param/sharded_tensor.py @@ -7,7 +7,7 @@ class ShardedTensor(object): def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None: r""" - A tensor sharded in multiple processes. + A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. """ self._payload = tensor self.process_group = process_group diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py index b181c7a5f..e29a266eb 100644 --- a/tests/test_zero_data_parallel/test_init_context.py +++ b/tests/test_zero_data_parallel/test_init_context.py @@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \ TensorShardStrategy from tests.components_to_test.registry import non_distributed_component_funcs -from common import CONFIG, Net +from common import CONFIG +from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER def run_dist(rank, world_size, port): @@ -33,9 +34,12 @@ def run_dist(rank, world_size, port): assert param.col_attr.data.is_sharded assert param.col_attr.data.payload.device.type == 'cuda' + print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}') + assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0) + @pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) def test_zero_init_context(world_size): run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size)