mirror of https://github.com/hpcaitech/ColossalAI
[zero] global model data memory tracer (#360)
parent cb34cd384d
commit ea2872073f
@@ -0,0 +1,18 @@
+class SingletonMeta(type):
+    """
+    The Singleton class can be implemented in different ways in Python. Some
+    possible methods include: base class, decorator, metaclass. We will use the
+    metaclass because it is best suited for this purpose.
+    """
+
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        """
+        Possible changes to the value of the `__init__` argument do not affect
+        the returned instance.
+        """
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
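
A quick sanity check of the metaclass (an illustrative sketch, not part of the commit): every construction of a class that uses SingletonMeta returns the same instance, and constructor arguments passed after the first call are ignored.

    from colossalai.utils.commons.singleton_meta import SingletonMeta

    class Config(metaclass=SingletonMeta):
        def __init__(self, value=0):
            self.value = value

    a = Config(1)
    b = Config(2)   # __init__ is not re-run; the first instance is returned
    assert a is b and b.value == 1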
@@ -0,0 +1,60 @@
+import torch
+from colossalai.utils.commons.singleton_meta import SingletonMeta
+from colossalai.zero.sharded_param import ShardedTensor
+
+from typing import Union
+
+
+def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
+    if isinstance(t, ShardedTensor):
+        target = t.payload
+    else:
+        target = t
+    return target.numel() * target.element_size()
+
+
+class ModelDataTracer(metaclass=SingletonMeta):
+    """
+    A singleton to trace model data usage during runtime.
+    """
+
+    def __init__(self) -> None:
+        self._cpu_usage = 0
+        self._cuda_usage = 0
+
+    def trace_tensor(self, t: torch.Tensor):
+        mem_use = col_tensor_mem_usage(t)
+        if t.device.type == 'cpu':
+            self._cpu_usage += mem_use
+        elif t.device.type == 'cuda':
+            self._cuda_usage += mem_use
+        else:
+            raise RuntimeError(f"unrecognized device type: {t.device.type}")
+
+    def detach_tensor(self, t: torch.Tensor):
+        mem_use = col_tensor_mem_usage(t)
+        if t.device.type == 'cpu':
+            self._cpu_usage -= mem_use
+        elif t.device.type == 'cuda':
+            self._cuda_usage -= mem_use
+        else:
+            raise RuntimeError(f"unrecognized device type: {t.device.type}")
+
+    @property
+    def cpu_usage(self):
+        return self._cpu_usage
+
+    @property
+    def cuda_usage(self):
+        return self._cuda_usage
+
+
+GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
+
+
+def col_allocate_payload(device: torch.device) -> torch.Tensor:
+    pass
+
+
+def col_release_payload(t: torch.Tensor):
+    pass
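
A rough usage sketch of the tracer (illustrative, not part of the commit): it accumulates the byte footprint of each traced tensor per device type, and detach_tensor reverses the accounting.

    import torch
    from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

    t = torch.zeros(1024, dtype=torch.float32)      # 1024 elements * 4 bytes, on CPU
    GLOBAL_MODEL_DATA_TRACER.trace_tensor(t)
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 4096
    GLOBAL_MODEL_DATA_TRACER.detach_tensor(t)       # stop counting this tensor
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0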
@@ -4,6 +4,7 @@ import torch
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 
 
 # Inserts _post_init_method at the end of init method
@@ -76,11 +77,16 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
+    r"""
     A context to initialize model.
     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
+    rm_torch_payload_on_the_fly:
+        True: remove tensor payload on param.data after module init finishes.
+        False: remove tensor payload on param.data after the context exits.
+        This is used when you add some logic to operate on tensors in the __init__ of a module.
+        See torchvision resnet18.
     """
 
     def __init__(self,
@@ -134,5 +140,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
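
Putting the pieces together (a hedged sketch, not part of the commit; the keyword arguments below are assumptions and may not match the actual constructor signature): parameters built under the context are sharded on construction, and each shard's payload is registered with GLOBAL_MODEL_DATA_TRACER.

    import torch
    import torchvision
    from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy

    with ZeroInitContext(convert_fp16=True,                      # assumed kwarg
                         target_device=torch.device('cuda'),     # assumed kwarg
                         shard_strategy=TensorShardStrategy(),   # assumed kwarg
                         shard_param=True):                      # assumed kwarg
        model = torchvision.models.resnet18()
    # GLOBAL_MODEL_DATA_TRACER.cuda_usage now reflects the sharded parameter payloads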
@@ -7,7 +7,7 @@ class ShardedTensor(object):
 
     def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
         r"""
-        A tensor sharded in multiple processes.
+        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
         self._payload = tensor
         self.process_group = process_group
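
For context (an illustrative sketch): a ShardedTensor wraps an existing tensor and exposes it through its payload, which is exactly what col_tensor_mem_usage and the trace_tensor calls above read.

    import torch
    from colossalai.zero.sharded_param import ShardedTensor
    from colossalai.utils.memory_tracer.allocator import col_tensor_mem_usage

    st = ShardedTensor(torch.zeros(16, dtype=torch.float32))
    assert col_tensor_mem_usage(st) == 16 * 4   # numel * element_size of the payload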
@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from tests.components_to_test.registry import non_distributed_component_funcs
 
-from common import CONFIG, Net
+from common import CONFIG
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 
 
 def run_dist(rank, world_size, port):
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):
         assert param.col_attr.data.is_sharded
         assert param.col_attr.data.payload.device.type == 'cuda'
 
+    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 4])
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)