mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] gemini use the runtime memory tracer (RMT) (#2099)
parent
2bf2d1cd3b
commit
85efb7ac2e
|
@ -5,8 +5,9 @@ from typing import List, Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||||
|
from colossalai.gemini.memory_tracer import MemStats
|
||||||
|
|
||||||
from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector
|
from .memory_tracer import ChunkMemStatsCollector
|
||||||
from .placement_policy import PlacementPolicyFactory
|
from .placement_policy import PlacementPolicyFactory
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,13 +27,14 @@ class GeminiManager:
|
||||||
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
|
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||||
|
|
||||||
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
||||||
self.policy_name = placement_policy
|
self.policy_name = placement_policy
|
||||||
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
|
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
|
||||||
|
memstats) if policy_cls.need_mem_stats else None
|
||||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
||||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||||
self._compute_idx: int = -1
|
self._compute_idx: int = -1
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
|
from .memory_stats import MemStats # isort:skip
|
||||||
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
||||||
from .memstats_collector import MemStatsCollector # isort:skip
|
from .memstats_collector import MemStatsCollector # isort:skip
|
||||||
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
||||||
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
||||||
from .memory_stats import MemStats
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from colossalai.gemini.chunk import ChunkManager
|
from colossalai.gemini.chunk import ChunkManager
|
||||||
|
from colossalai.gemini.memory_tracer import MemStats
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.utils.memory import colo_device_memory_capacity
|
from colossalai.utils.memory import colo_device_memory_capacity
|
||||||
|
|
||||||
|
@ -7,15 +10,15 @@ from .memstats_collector import MemStatsCollector
|
||||||
|
|
||||||
class ChunkMemStatsCollector(MemStatsCollector):
|
class ChunkMemStatsCollector(MemStatsCollector):
|
||||||
|
|
||||||
def __init__(self, chunk_manager: ChunkManager) -> None:
|
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||||
super().__init__()
|
super().__init__(memstats)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
|
|
||||||
# override
|
# override
|
||||||
def sample_model_data(self) -> None:
|
def sample_model_data(self) -> None:
|
||||||
"""Sampling model data statistics.
|
"""Sampling model data statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag:
|
if self._start_flag and not self.use_outside_memstats:
|
||||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
cpu_mem = self._chunk_manager.total_mem['cpu']
|
||||||
self._memstats.append_model_data('cuda', cuda_mem)
|
self._memstats.append_model_data('cuda', cuda_mem)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -22,14 +22,19 @@ class MemStatsCollector:
|
||||||
It has a Sampling counter which is reset after DNN training iteration.
|
It has a Sampling counter which is reset after DNN training iteration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, memstats: Optional[MemStats] = None) -> None:
|
||||||
self._mem_monitor = SyncCudaMemoryMonitor()
|
self._mem_monitor = SyncCudaMemoryMonitor()
|
||||||
self._sampling_time = []
|
self._sampling_time = []
|
||||||
|
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
self._step_idx = 0
|
self._step_idx = 0
|
||||||
self._step_total = 0
|
self._step_total = 0
|
||||||
self._memstats = MemStats()
|
if memstats is not None:
|
||||||
|
self.use_outside_memstats = True
|
||||||
|
self._memstats = memstats
|
||||||
|
else:
|
||||||
|
self.use_outside_memstats = False
|
||||||
|
self._memstats = MemStats()
|
||||||
|
|
||||||
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
||||||
"""Get max non model data memory usage of current sampling period
|
"""Get max non model data memory usage of current sampling period
|
||||||
|
@ -63,7 +68,7 @@ class MemStatsCollector:
|
||||||
def sample_model_data(self) -> None:
|
def sample_model_data(self) -> None:
|
||||||
"""Sampling model data statistics.
|
"""Sampling model data statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag:
|
if self._start_flag and not self.use_outside_memstats:
|
||||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||||
self._memstats.append_model_data('cuda', cuda_mem)
|
self._memstats.append_model_data('cuda', cuda_mem)
|
||||||
|
@ -72,7 +77,7 @@ class MemStatsCollector:
|
||||||
def sample_overall_data(self) -> None:
|
def sample_overall_data(self) -> None:
|
||||||
"""Sampling non model data statistics.
|
"""Sampling non model data statistics.
|
||||||
"""
|
"""
|
||||||
if self._start_flag:
|
if self._start_flag and not self.use_outside_memstats:
|
||||||
# overall data recording is after model data recording
|
# overall data recording is after model data recording
|
||||||
if len(self._memstats._model_data_cuda_list) == 0:
|
if len(self._memstats._model_data_cuda_list) == 0:
|
||||||
return
|
return
|
||||||
|
@ -84,9 +89,11 @@ class MemStatsCollector:
|
||||||
|
|
||||||
self._memstats.append_non_model_data('cuda')
|
self._memstats.append_non_model_data('cuda')
|
||||||
self._memstats.append_non_model_data('cpu')
|
self._memstats.append_non_model_data('cpu')
|
||||||
self._sampling_time.append(time.time())
|
|
||||||
self._mem_monitor.start()
|
self._mem_monitor.start()
|
||||||
|
|
||||||
|
if self._start_flag:
|
||||||
|
self._sampling_time.append(time.time())
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self._memstats.clear()
|
self._memstats.clear()
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
|
|
|
@ -35,6 +35,9 @@ class RuntimeMemTracer():
|
||||||
|
|
||||||
self._cast_buffers_to_cuda_dtype()
|
self._cast_buffers_to_cuda_dtype()
|
||||||
|
|
||||||
|
def memstats(self):
|
||||||
|
return self._memstats
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.forward(*args, **kwargs)
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||||
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
|
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
|
||||||
|
from colossalai.nn.parallel import ZeroDDP
|
||||||
|
from colossalai.tensor import ProcessGroup
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
from tests.components_to_test import run_fwd_bwd
|
||||||
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
from tests.test_tensor.common_utils import set_seed
|
||||||
|
|
||||||
|
# run gemini use the runtime memory tracer
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('placement_policy', ['auto'])
|
||||||
|
@parameterize('keep_gather', [False])
|
||||||
|
@parameterize('model_name', ['bert', 'albert', 'gpt2'])
|
||||||
|
@parameterize('use_grad_checkpoint', [False, True])
|
||||||
|
def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
|
||||||
|
set_seed(42)
|
||||||
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
|
with ColoInitContext(device='cpu'):
|
||||||
|
model = model_builder(use_grad_checkpoint)
|
||||||
|
|
||||||
|
print(f'model_name {model_name}')
|
||||||
|
runtime_mem_tracer = RuntimeMemTracer(model)
|
||||||
|
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||||
|
if i > 0:
|
||||||
|
break
|
||||||
|
input_ids, label = input_ids.cuda(), label.cuda()
|
||||||
|
|
||||||
|
# mem tracing
|
||||||
|
if i == 0:
|
||||||
|
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
|
||||||
|
memstats = runtime_mem_tracer.memstats()
|
||||||
|
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
|
||||||
|
print('runtime tracer: ', runtime_tracer_non_model_data)
|
||||||
|
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||||
|
config_dict[world_size]['chunk_size'] = 5000
|
||||||
|
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||||
|
chunk_manager = ChunkManager(config_dict)
|
||||||
|
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||||
|
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||||
|
|
||||||
|
pg = ProcessGroup()
|
||||||
|
set_seed(pg.dp_local_rank())
|
||||||
|
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||||
|
# you can only test a single fwd + bwd.
|
||||||
|
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
|
||||||
|
if i > 1:
|
||||||
|
break
|
||||||
|
input_ids, label = input_ids.cuda(), label.cuda()
|
||||||
|
|
||||||
|
set_seed(42)
|
||||||
|
loss = run_fwd_bwd(model, input_ids, label, criterion, model)
|
||||||
|
|
||||||
|
gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
|
||||||
|
|
||||||
|
# print('gemini non model data:', gemini_non_model_data)
|
||||||
|
|
||||||
|
assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \
|
||||||
|
f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}'
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
config = {}
|
||||||
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
run_gemini_use_rmt()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_gemini_use_rmt(world_size):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_gemini_use_rmt(1)
|
Loading…
Reference in New Issue