mirror of https://github.com/hpcaitech/ColossalAI
[zero] global model data memory tracer (#360)
parent cb34cd384d
commit ea2872073f
@@ -0,0 +1,18 @@
+class SingletonMeta(type):
+    """
+    The Singleton class can be implemented in different ways in Python. Some
+    possible methods include: base class, decorator, metaclass. We will use the
+    metaclass because it is best suited for this purpose.
+    """
+
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        """
+        Possible changes to the value of the `__init__` argument do not affect
+        the returned instance.
+        """
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
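A quick sanity check of the metaclass behavior (a hypothetical snippet, not part of the commit): once a class has been constructed, later instantiations return the cached instance and their constructor arguments are ignored.

    class Config(metaclass=SingletonMeta):
        def __init__(self, value=0):
            self.value = value

    a = Config(value=1)
    b = Config(value=2)    # cached instance returned; __init__ is not re-run
    assert a is b and a.value == 1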
@@ -0,0 +1,60 @@
+import torch
+
+from colossalai.utils.commons.singleton_meta import SingletonMeta
+from colossalai.zero.sharded_param import ShardedTensor
+
+from typing import Union
+
+
+def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
+    # A ShardedTensor wraps its data; measure the underlying payload.
+    if isinstance(t, ShardedTensor):
+        target = t.payload
+    else:
+        target = t
+    return target.numel() * target.element_size()
+
+
+class ModelDataTracer(metaclass=SingletonMeta):
+    """
+    A singleton to trace model data usage during runtime.
+    """
+
+    def __init__(self) -> None:
+        self._cpu_usage = 0
+        self._cuda_usage = 0
+
+    def trace_tensor(self, t: torch.Tensor):
+        mem_use = col_tensor_mem_usage(t)
+        if t.device.type == 'cpu':
+            self._cpu_usage += mem_use
+        elif t.device.type == 'cuda':
+            self._cuda_usage += mem_use
+        else:
+            raise RuntimeError(f"unsupported device type {t.device.type}")
+
+    def detach_tensor(self, t: torch.Tensor):
+        mem_use = col_tensor_mem_usage(t)
+        if t.device.type == 'cpu':
+            self._cpu_usage -= mem_use
+        elif t.device.type == 'cuda':
+            self._cuda_usage -= mem_use
+        else:
+            raise RuntimeError(f"unsupported device type {t.device.type}")
+
+    @property
+    def cpu_usage(self):
+        return self._cpu_usage
+
+    @property
+    def cuda_usage(self):
+        return self._cuda_usage
+
+
+GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
+
+
+def col_allocate_payload(device: torch.device) -> torch.Tensor:
+    pass    # stub in this commit
+
+
+def col_release_payload(t: torch.Tensor):
+    pass    # stub in this commit
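To illustrate the accounting, a minimal sketch of the tracer in use (hypothetical, assumes a fresh tracer whose counters are still zero): col_tensor_mem_usage is just numel() * element_size(), so a 1024-element fp16 tensor counts as 2048 bytes.

    import torch

    t = torch.zeros(1024, dtype=torch.float16)    # 1024 elements * 2 bytes each
    GLOBAL_MODEL_DATA_TRACER.trace_tensor(t)      # t.device.type == 'cpu'
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2048
    GLOBAL_MODEL_DATA_TRACER.detach_tensor(t)     # accounting is symmetric
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0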
@@ -4,6 +4,7 @@ import torch
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


 # Inserts _post_init_method at the end of init method
@@ -76,11 +77,16 @@ class InsertPostInitMethodToModuleSubClasses(object):


 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
+    r"""
     A context to initialize model.
     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
+    rm_torch_payload_on_the_fly:
+        True: remove tensor payload on param.data after module init finishes.
+        False: remove tensor payload on param.data after the context exits.
+        This is used when you add some logic operating on tensors in the module's __init__.
+        See torchvision resnet18.
     """

     def __init__(self,
@@ -134,5 +140,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
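For context, a sketch of how this init path is typically driven. The __init__ signature is truncated in this diff, so the keyword names below (shard_strategy, shard_param) are assumptions inferred from the flags the docstring and hunk reference.

    import torchvision
    from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy

    with ZeroInitContext(shard_strategy=TensorShardStrategy(), shard_param=True):
        model = torchvision.models.resnet18()
    # each sharded payload was passed to GLOBAL_MODEL_DATA_TRACER.trace_tensor,
    # so GLOBAL_MODEL_DATA_TRACER.cuda_usage now reflects model data on the GPU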
@@ -7,7 +7,7 @@ class ShardedTensor(object):

     def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
         r"""
-        A tensor sharded in multiple processes.
+        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
         self._payload = tensor
         self.process_group = process_group
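Tying this back to the tracer: col_tensor_mem_usage accepts either a plain torch.Tensor or a ShardedTensor and unwraps the latter through its payload. A minimal check (hypothetical snippet, assuming the payload property exposes the wrapped tensor, as the tracer code implies):

    import torch

    t = torch.zeros(256, dtype=torch.float32)    # 256 elements * 4 bytes each
    st = ShardedTensor(t)                        # process_group defaults to None
    assert col_tensor_mem_usage(st) == col_tensor_mem_usage(t) == 1024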
@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from tests.components_to_test.registry import non_distributed_component_funcs

-from common import CONFIG, Net
+from common import CONFIG
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


 def run_dist(rank, world_size, port):
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):
         assert param.col_attr.data.is_sharded
         assert param.col_attr.data.payload.device.type == 'cuda'

+    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0
+

 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 4])
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
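For local debugging, the spawn harness above can also be driven without pytest (a sketch; assumes a CUDA device is available, since the asserts check device.type == 'cuda'):

    if __name__ == '__main__':
        test_zero_init_context(world_size=1)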