
[zero] global model data memory tracer (#360)

pull/394/head · commit ea2872073f · Jiarui Fang, 3 years ago, committed by Frank Lee
Files changed:

  1. colossalai/utils/commons/singleton_meta.py (18)
  2. colossalai/utils/memory_tracer/allocator.py (60)
  3. colossalai/zero/init_ctx/init_context.py (10)
  4. colossalai/zero/sharded_param/sharded_tensor.py (2)
  5. tests/test_zero_data_parallel/test_init_context.py (8)

colossalai/utils/commons/singleton_meta.py (new file, 18 lines)

@@ -0,0 +1,18 @@
class SingletonMeta(type):
    """
    The Singleton class can be implemented in different ways in Python. Some
    possible methods include: base class, decorator, metaclass. We use the
    metaclass because it is best suited for this purpose.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the value of the `__init__` argument do not affect
        the returned instance.
        """
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return cls._instances[cls]
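
Since ModelDataTracer below uses this metaclass, repeated construction always yields the same instance. A minimal sketch of the behavior (Config is a hypothetical class, not part of this commit):

    class Config(metaclass=SingletonMeta):
        def __init__(self, value=0):
            self.value = value

    a = Config(value=1)
    b = Config(value=2)   # second call: __init__ args are ignored, same object returned
    assert a is b and a.value == 1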

colossalai/utils/memory_tracer/allocator.py (new file, 60 lines)

@@ -0,0 +1,60 @@
from typing import Union

import torch

from colossalai.utils.commons.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param import ShardedTensor


def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    """Return the memory footprint, in bytes, of a tensor or a ShardedTensor's payload."""
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()


class ModelDataTracer(metaclass=SingletonMeta):
    """
    A singleton to trace model data usage during runtime.
    """

    def __init__(self) -> None:
        self._cpu_usage = 0
        self._cuda_usage = 0

    def trace_tensor(self, t: torch.Tensor):
        mem_use = col_tensor_mem_usage(t)
        if t.device.type == 'cpu':
            self._cpu_usage += mem_use
        elif t.device.type == 'cuda':
            self._cuda_usage += mem_use
        else:
            raise RuntimeError(f"unsupported device type: {t.device.type}")

    def detach_tensor(self, t: torch.Tensor):
        mem_use = col_tensor_mem_usage(t)
        if t.device.type == 'cpu':
            self._cpu_usage -= mem_use
        elif t.device.type == 'cuda':
            self._cuda_usage -= mem_use
        else:
            raise RuntimeError(f"unsupported device type: {t.device.type}")

    @property
    def cpu_usage(self):
        return self._cpu_usage

    @property
    def cuda_usage(self):
        return self._cuda_usage


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()


def col_allocate_payload(device: torch.device) -> torch.Tensor:
    # placeholder: payload allocation hook, not yet implemented in this commit
    pass


def col_release_payload(t: torch.Tensor):
    # placeholder: payload release hook, not yet implemented in this commit
    pass
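
A quick sanity check of the accounting, as a standalone sketch (assumes a freshly constructed tracer; fp16 elements are 2 bytes each):

    import torch
    from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

    t = torch.zeros(1024, 1024, dtype=torch.float16)   # 1024*1024 elements * 2 bytes
    GLOBAL_MODEL_DATA_TRACER.trace_tensor(t)
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 1024 * 1024
    GLOBAL_MODEL_DATA_TRACER.detach_tensor(t)          # detaching reverses the accounting
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0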

colossalai/zero/init_ctx/init_context.py (10 lines changed)

@@ -4,6 +4,7 @@ import torch
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

 # Inserts _post_init_method at the end of init method
@@ -76,11 +77,16 @@ class InsertPostInitMethodToModuleSubClasses(object):

 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
+    r"""
     A context to initialize model.
     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
+
+    rm_torch_payload_on_the_fly:
+        True: remove tensor payload on param.data after module init finishes.
+        False: remove tensor payload on param.data after the context exits.
+            This is used when logic in a module's __init__ operates on the tensors;
+            see torchvision resnet18.
     """

     def __init__(self,
@@ -134,5 +140,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
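
With these two calls in place, every parameter payload (and, when shard_grad is set, every gradient payload) sharded inside the context is registered with the global tracer, so total CUDA model data can be read back after initialization. A minimal sketch, assuming constructor arguments along these lines (the full __init__ signature is truncated in this diff, so the argument names here are assumptions):

    import torch
    from colossalai.zero.init_ctx.init_context import ZeroInitContext
    from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
    from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

    # argument names are hypothetical; consult the full __init__ above
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = torch.nn.Linear(1024, 1024)

    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')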

colossalai/zero/sharded_param/sharded_tensor.py (2 lines changed)

@@ -7,7 +7,7 @@ class ShardedTensor(object):

     def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
         r"""
-        A tensor sharded in multiple processes.
+        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
         self._payload = tensor
         self.process_group = process_group
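
Because col_tensor_mem_usage in the new allocator module unwraps the payload, a ShardedTensor can be measured directly. A small sketch, assuming construction works without an initialized process group (the truncated __init__ above suggests it does):

    import torch
    from colossalai.zero.sharded_param import ShardedTensor
    from colossalai.utils.memory_tracer.allocator import col_tensor_mem_usage

    st = ShardedTensor(torch.zeros(256, dtype=torch.float32))
    assert col_tensor_mem_usage(st) == 256 * 4   # 256 elements * 4 bytes each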

tests/test_zero_data_parallel/test_init_context.py (8 lines changed)

@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from tests.components_to_test.registry import non_distributed_component_funcs
-from common import CONFIG, Net
+from common import CONFIG
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

 def run_dist(rank, world_size, port):
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):

         assert param.col_attr.data.is_sharded
         assert param.col_attr.data.payload.device.type == 'cuda'
+    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0

 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 4])
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
