mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] remove GLOBAL_MODEL_DATA_TRACER (#2091)
parent
28e55c2530
commit
1fca5d79ea
|
@ -1,11 +1,10 @@
|
||||||
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
||||||
from .memstats_collector import MemStatsCollector # isort:skip
|
from .memstats_collector import MemStatsCollector # isort:skip
|
||||||
from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip
|
|
||||||
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
||||||
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
||||||
from .memory_stats import MemStats
|
from .memory_stats import MemStats
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
||||||
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemStats'
|
'StaticMemStatsCollector', 'MemStats'
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,9 +2,6 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.context.singleton_meta import SingletonMeta
|
|
||||||
from colossalai.logging import DistributedLogger
|
|
||||||
|
|
||||||
|
|
||||||
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
|
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
|
||||||
"""Trace the optimizer memory usage
|
"""Trace the optimizer memory usage
|
||||||
|
@ -60,52 +57,3 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
cpu_mem_usage += t_cpu
|
cpu_mem_usage += t_cpu
|
||||||
|
|
||||||
return cuda_mem_usage, cpu_mem_usage
|
return cuda_mem_usage, cpu_mem_usage
|
||||||
|
|
||||||
|
|
||||||
class ModelDataTracer(metaclass=SingletonMeta):
|
|
||||||
"""
|
|
||||||
A tracer singleton to trace model data usage during runtime.
|
|
||||||
You have to register a model on the singleton first.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._logger = DistributedLogger("ModelDataTracer")
|
|
||||||
self._model = None
|
|
||||||
self._opitimizer = None
|
|
||||||
|
|
||||||
def _get_mem_usage(self) -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
get the memory usage of the model registered.
|
|
||||||
Returns:
|
|
||||||
Tuple[int, int]: cuda, cpu mem usage
|
|
||||||
"""
|
|
||||||
cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer)
|
|
||||||
cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model)
|
|
||||||
return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model
|
|
||||||
|
|
||||||
def register_model(self, model) -> None:
|
|
||||||
if self._model is not None:
|
|
||||||
self._logger.warning("ModelDataTracer has already registered a model")
|
|
||||||
self._model = model
|
|
||||||
|
|
||||||
def register_optimizer(self, optimizer) -> None:
|
|
||||||
if self._opitimizer is not None:
|
|
||||||
self._logger.warning("ModelDataTracer has already registered an optimizer")
|
|
||||||
self._opitimizer = optimizer
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cpu_usage(self):
|
|
||||||
_, cpu_usage = self._get_mem_usage()
|
|
||||||
return cpu_usage
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cuda_usage(self):
|
|
||||||
cuda_usage, _ = self._get_mem_usage()
|
|
||||||
return cuda_usage
|
|
||||||
|
|
||||||
@property
|
|
||||||
def both_mem_usage(self):
|
|
||||||
return self._get_mem_usage()
|
|
||||||
|
|
||||||
|
|
||||||
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
|
|
@ -1,121 +0,0 @@
|
||||||
import torch
|
|
||||||
import colossalai
|
|
||||||
import pytest
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from colossalai.utils.cuda import get_current_device
|
|
||||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
|
||||||
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
|
||||||
from colossalai.utils.memory import colo_set_process_memory_fraction
|
|
||||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
|
||||||
from colossalai.gemini.stateful_tensor import TensorState
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
from typing import List
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from colossalai.gemini import StatefulTensorMgr
|
|
||||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
|
||||||
|
|
||||||
|
|
||||||
class Net(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
# each parameter is 128 MB
|
|
||||||
self.p0 = Parameter(torch.empty(1024, 1024, 32))
|
|
||||||
self.p1 = Parameter(torch.empty(1024, 1024, 32))
|
|
||||||
self.p2 = Parameter(torch.empty(1024, 1024, 32))
|
|
||||||
|
|
||||||
|
|
||||||
def limit_cuda_memory(memory_in_g: float):
|
|
||||||
cuda_capacity = torch.cuda.get_device_properties(get_current_device()).total_memory
|
|
||||||
fraction = (memory_in_g * 1024**3) / cuda_capacity
|
|
||||||
colo_set_process_memory_fraction(fraction)
|
|
||||||
|
|
||||||
|
|
||||||
def run_stm():
|
|
||||||
# warmup phase use 20% CUDA memory to store params
|
|
||||||
# only 2 params can be on CUDA
|
|
||||||
limit_cuda_memory(1.26)
|
|
||||||
model = Net()
|
|
||||||
for p in model.parameters():
|
|
||||||
p.colo_attr = ShardedParamV2(p, set_data_none=True)
|
|
||||||
GLOBAL_MODEL_DATA_TRACER.register_model(model)
|
|
||||||
mem_collector = MemStatsCollector()
|
|
||||||
tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
|
|
||||||
stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
|
|
||||||
stateful_tensors = [p.colo_attr.sharded_data_tensor for p in model.parameters()]
|
|
||||||
stateful_tensor_mgr.register_stateful_tensor_list(stateful_tensors)
|
|
||||||
|
|
||||||
mem_collector.start_collection()
|
|
||||||
# Compute order: 0 1 2 0 1
|
|
||||||
# warmup
|
|
||||||
# use naive eviction strategy
|
|
||||||
apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
|
|
||||||
mem_collector.sample_model_data()
|
|
||||||
mem_collector.sample_overall_data()
|
|
||||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
|
||||||
mem_collector.sample_model_data()
|
|
||||||
mem_collector.sample_overall_data()
|
|
||||||
apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
|
|
||||||
mem_collector.sample_model_data()
|
|
||||||
mem_collector.sample_overall_data()
|
|
||||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
|
||||||
mem_collector.sample_model_data()
|
|
||||||
mem_collector.sample_overall_data()
|
|
||||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
|
||||||
mem_collector.sample_model_data()
|
|
||||||
mem_collector.finish_collection()
|
|
||||||
stateful_tensor_mgr.finish_iter()
|
|
||||||
|
|
||||||
# warmup done
|
|
||||||
# only 2 params can be on CUDA
|
|
||||||
limit_cuda_memory(0.26 / tensor_placement_policy._steady_cuda_cap_ratio)
|
|
||||||
# use OPT-like eviction strategy
|
|
||||||
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
|
|
||||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
|
||||||
apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
|
|
||||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
|
||||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
|
|
||||||
stateful_tensor_mgr: StatefulTensorMgr):
|
|
||||||
compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
|
||||||
for p in model.parameters():
|
|
||||||
if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
|
|
||||||
p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
|
|
||||||
stateful_tensor_mgr.adjust_layout()
|
|
||||||
print_stats(model)
|
|
||||||
device = torch.device(torch.cuda.current_device())
|
|
||||||
cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if hash(p) in cuda_param_after_adjust:
|
|
||||||
assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
|
|
||||||
else:
|
|
||||||
assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')
|
|
||||||
|
|
||||||
|
|
||||||
def print_stats(model: torch.nn.Module):
|
|
||||||
msgs = []
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
|
|
||||||
print(f'[ {", ".join(msgs)} ]')
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
||||||
run_stm()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_stateful_tensor_manager(world_size=1):
|
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# this unit test can pass if available CUDA memory >= 1.5G
|
|
||||||
test_stateful_tensor_manager()
|
|
|
@ -3,23 +3,22 @@
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import colossalai
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
from common import CONFIG
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.gemini.memory_tracer.utils import colo_model_mem_usage
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
|
||||||
colo_model_mem_usage
|
|
||||||
from colossalai.utils.memory import colo_device_memory_used
|
from colossalai.utils.memory import colo_device_memory_used
|
||||||
from colossalai.zero.init_ctx import ZeroInitContext
|
from colossalai.zero.init_ctx import ZeroInitContext
|
||||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
|
||||||
from common import CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize("init_device_type", ['cpu', 'cuda'])
|
@parameterize("init_device_type", ['cpu', 'cuda'])
|
||||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||||
|
|
Loading…
Reference in New Issue