[zero] global model data memory tracer (#360)

pull/394/head
Jiarui Fang 2022-03-10 11:20:04 +08:00 committed by Frank Lee
parent cb34cd384d
commit ea2872073f
5 changed files with 94 additions and 4 deletions

View File

@@ -0,0 +1,18 @@
class SingletonMeta(type):
    """
    The Singleton class can be implemented in different ways in Python. Some
    possible methods include: base class, decorator, metaclass. We will use the
    metaclass because it is best suited for this purpose.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the value of the `__init__` arguments do not affect
        the returned instance.
        """
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return cls._instances[cls]
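A minimal usage sketch (illustrative only, not part of this commit; the `Config` class is hypothetical): any class created with this metaclass yields one cached instance, and constructor arguments passed on later calls are ignored.

class Config(metaclass=SingletonMeta):
    def __init__(self, value: int = 0) -> None:
        self.value = value

a = Config(value=1)
b = Config(value=99)   # second call returns the cached instance; __init__ is not re-run
assert a is b
assert b.value == 1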

View File

@@ -0,0 +1,60 @@
import torch
from typing import Union

from colossalai.utils.commons.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param import ShardedTensor


def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    """Return the memory footprint of a tensor (or a ShardedTensor's payload) in bytes."""
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()


class ModelDataTracer(metaclass=SingletonMeta):
    """
    A singleton to trace model data usage during runtime.
    """

    def __init__(self) -> None:
        self._cpu_usage = 0
        self._cuda_usage = 0

    def trace_tensor(self, t: torch.Tensor):
        """Add the tensor's memory footprint to the counter of its device."""
        mem_use = col_tensor_mem_usage(t)
        if t.device.type == 'cpu':
            self._cpu_usage += mem_use
        elif t.device.type == 'cuda':
            self._cuda_usage += mem_use
        else:
            raise RuntimeError(f"Unsupported device type {t.device.type}")

    def detach_tensor(self, t: torch.Tensor):
        """Subtract the tensor's memory footprint from the counter of its device."""
        mem_use = col_tensor_mem_usage(t)
        if t.device.type == 'cpu':
            self._cpu_usage -= mem_use
        elif t.device.type == 'cuda':
            self._cuda_usage -= mem_use
        else:
            raise RuntimeError(f"Unsupported device type {t.device.type}")

    @property
    def cpu_usage(self):
        return self._cpu_usage

    @property
    def cuda_usage(self):
        return self._cuda_usage


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()


def col_allocate_payload(device: torch.device) -> torch.Tensor:
    # Placeholder in this commit; not yet implemented.
    pass


def col_release_payload(t: torch.Tensor):
    # Placeholder in this commit; not yet implemented.
    pass
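For illustration only (not part of this commit), a hypothetical session with a fresh tracer; the byte counts follow from numel() * element_size():

t = torch.zeros(1024, dtype=torch.float16)       # 1024 elements * 2 bytes = 2048 bytes on CPU
GLOBAL_MODEL_DATA_TRACER.trace_tensor(t)
assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2048
GLOBAL_MODEL_DATA_TRACER.detach_tensor(t)
assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0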

View File

@@ -4,6 +4,7 @@ import torch
from colossalai.utils.cuda import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
# Inserts _post_init_method at the end of the __init__ method
@@ -76,11 +77,16 @@ class InsertPostInitMethodToModuleSubClasses(object):
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
+    r"""
    A context to initialize the model.
    1. Convert the model to fp16.
    2. The parameters of the module are adapted to type ShardedParameter.
    3. Shard the param and grad according to flags.

    rm_torch_payload_on_the_fly:
        True: remove the tensor payload on param.data after module init finishes.
        False: remove the tensor payload on param.data after the context exits.
            This is used when you add logic to operate on tensors in the __init__ of the module.
            See torchvision resnet18.
    """
def __init__(self,
@@ -134,5 +140,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        if self.shard_param:
            self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
        if param.col_attr.grad and self.shard_grad:
            self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
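A hedged usage sketch of the context: the constructor is truncated in this hunk, so every keyword argument below is an assumption inferred from the docstring and the flags used above, not a confirmed signature.

# Hypothetical usage; argument names inferred from this diff, not verified.
with ZeroInitContext(shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     shard_grad=True,
                     rm_torch_payload_on_the_fly=False):
    model = get_model()   # get_model() is a placeholder; build the model inside the context
# Sharded param/grad payloads created here are recorded in
# GLOBAL_MODEL_DATA_TRACER as _post_init_method runs.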

View File

@@ -7,7 +7,7 @@ class ShardedTensor(object):

    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
        r"""
-        A tensor sharded in multiple processes.
+        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
        """
        self._payload = tensor
        self.process_group = process_group
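A small sketch of how the wrapper interacts with col_tensor_mem_usage from allocator.py (illustrative, assuming the payload property used earlier in this commit exposes the wrapped tensor):

t = torch.ones(4, 4, dtype=torch.float32)    # 16 elements * 4 bytes = 64 bytes
st = ShardedTensor(t)
assert col_tensor_mem_usage(st) == 64        # reads st.payload under the hood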

View File

@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from tests.components_to_test.registry import non_distributed_component_funcs
-from common import CONFIG, Net
+from common import CONFIG
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
def run_dist(rank, world_size, port):
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):
        assert param.col_attr.data.is_sharded
        assert param.col_attr.data.payload.device.type == 'cuda'

    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0


@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 4])
def test_zero_init_context(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)