mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] use MemStats to store the tracing data. Seperate it from Collector. (#2084)
parent
1f99205827
commit
33f4412102
|
@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
|
|||
super().__init__()
|
||||
self._chunk_manager = chunk_manager
|
||||
|
||||
# override
|
||||
def sample_model_data(self) -> None:
|
||||
"""Sampling model data statistics.
|
||||
"""
|
||||
if self._start_flag:
|
||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
||||
self._model_data_cuda_list.append(cuda_mem)
|
||||
self._model_data_cpu_list.append(cpu_mem)
|
||||
self._memstats.append_model_data('cuda', cuda_mem)
|
||||
self._memstats.append_model_data('cpu', cpu_mem)
|
||||
|
||||
@property
|
||||
def cuda_margin_mem(self) -> float:
|
||||
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
|
||||
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class MemStats(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Store the non model data statistics used for Gemini and ZeroOptimizer.
|
||||
"""
|
||||
# p -> list of non_model data volumn visied in order.
|
||||
self.param_non_model_data_map: Dict(Any, List[int]) = {}
|
||||
|
||||
self._model_data_cuda_list = []
|
||||
self._model_data_cpu_list = []
|
||||
|
||||
self._overall_cuda_list = []
|
||||
self._overall_cpu_list = []
|
||||
|
||||
self._non_model_data_cuda_list = []
|
||||
self._non_model_data_cpu_list = []
|
||||
|
||||
def append_overall_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._overall_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
self._overall_cpu_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def append_model_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._model_data_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
self._model_data_cpu_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def append_non_model_data(self, device_type: str):
|
||||
if device_type == 'cuda':
|
||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||
elif device_type == 'cpu':
|
||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def overall_mem_stats(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._overall_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._overall_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._non_model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._non_model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def max_non_model_data(self, device_type: str) -> float:
|
||||
if device_type == 'cuda':
|
||||
return max(self._non_model_data_cuda_list)
|
||||
elif device_type == 'cpu':
|
||||
return max(self._non_model_data_cpu_list)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def max_overall_cuda(self, device_type: str) -> float:
|
||||
if device_type == 'cuda':
|
||||
return max(self._overall_cuda_list)
|
||||
elif device_type == 'cpu':
|
||||
return max(self._overall_cpu_list)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def clear(self):
|
||||
self._model_data_cuda_list = []
|
||||
self._overall_cuda_list = []
|
||||
|
||||
self._model_data_cpu_list = []
|
||||
self._overall_cpu_list = []
|
||||
|
||||
self._non_model_data_cpu_list = []
|
||||
self._non_model_data_cuda_list = []
|
|
@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
|||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
|
||||
from .memory_stats import MemStats
|
||||
|
||||
|
||||
class MemStatsCollector:
|
||||
"""
|
||||
|
@ -22,43 +24,12 @@ class MemStatsCollector:
|
|||
|
||||
def __init__(self) -> None:
|
||||
self._mem_monitor = SyncCudaMemoryMonitor()
|
||||
self._model_data_cuda_list = []
|
||||
self._overall_cuda_list = []
|
||||
|
||||
self._model_data_cpu_list = []
|
||||
self._overall_cpu_list = []
|
||||
|
||||
self._non_model_data_cuda_list = []
|
||||
self._non_model_data_cpu_list = []
|
||||
self._sampling_time = []
|
||||
|
||||
self._start_flag = False
|
||||
self._step_idx = 0
|
||||
self._step_total = 0
|
||||
|
||||
def overall_mem_stats(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._overall_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._overall_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._non_model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._non_model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
self._memstats = MemStats()
|
||||
|
||||
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
||||
"""Get max non model data memory usage of current sampling period
|
||||
|
@ -71,7 +42,7 @@ class MemStatsCollector:
|
|||
"""
|
||||
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
||||
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
||||
next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
|
||||
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
|
||||
self._step_idx = (self._step_idx + 1) % self._step_total
|
||||
return next_non_model_data
|
||||
|
||||
|
@ -95,37 +66,29 @@ class MemStatsCollector:
|
|||
if self._start_flag:
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||
self._model_data_cuda_list.append(cuda_mem)
|
||||
self._model_data_cpu_list.append(cpu_mem)
|
||||
self._memstats.append_model_data('cuda', cuda_mem)
|
||||
self._memstats.append_model_data('cpu', cpu_mem)
|
||||
|
||||
def sample_overall_data(self) -> None:
|
||||
"""Sampling non model data statistics.
|
||||
"""
|
||||
if self._start_flag:
|
||||
# overall data recording is after model data recording
|
||||
if len(self._model_data_cuda_list) == 0:
|
||||
if len(self._memstats._model_data_cuda_list) == 0:
|
||||
return
|
||||
|
||||
self._overall_cuda_list.append(self._mem_monitor.finish())
|
||||
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
|
||||
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
|
||||
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
|
||||
|
||||
assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
|
||||
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
|
||||
|
||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||
self._memstats.append_non_model_data('cuda')
|
||||
self._memstats.append_non_model_data('cpu')
|
||||
self._sampling_time.append(time.time())
|
||||
self._mem_monitor.start()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._model_data_cuda_list = []
|
||||
self._overall_cuda_list = []
|
||||
|
||||
self._model_data_cpu_list = []
|
||||
self._overall_cpu_list = []
|
||||
|
||||
self._non_model_data_cpu_list = []
|
||||
self._non_model_data_cuda_list = []
|
||||
|
||||
self._memstats.clear()
|
||||
self._start_flag = False
|
||||
self._step_idx = 0
|
||||
self._step_total = 0
|
||||
|
|
|
@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
|
|||
tensor_placement_policy: str = 'cuda',
|
||||
gradient_predivide_factor: Optional[float] = 1.0,
|
||||
reuse_fp16_shard: bool = False,
|
||||
user_static_memstats: bool = False,
|
||||
*args,
|
||||
**kwargs):
|
||||
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
||||
|
@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
|
|||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.rank = dist.get_rank(self.process_group)
|
||||
self.shard_strategy = shard_strategy
|
||||
self.user_static_memstats = user_static_memstats
|
||||
|
||||
self._use_memory_tracer = tensor_placement_policy == 'auto'
|
||||
if self._use_memory_tracer:
|
||||
if self.user_static_memstats:
|
||||
self._memstats_collector = StaticMemStatsCollector(self.module)
|
||||
else:
|
||||
self._memstats_collector = MemStatsCollector()
|
||||
self._memstats_collector = MemStatsCollector()
|
||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||
else:
|
||||
|
@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
|
|||
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
||||
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
||||
f.write('CUDA model data (GB)\n')
|
||||
f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
|
||||
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
|
||||
f.write('\n')
|
||||
f.write('CUDA non model data (GB)\n')
|
||||
f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
|
||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
||||
f.write('CPU non model data (GB)\n')
|
||||
f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
|
||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
|
||||
f.write('\n')
|
||||
|
||||
def _pre_forward_operations(self, *args):
|
||||
# the operation will affect the memory tracer behavior in ZeroHook
|
||||
if self._memstats_collector:
|
||||
if self.user_static_memstats:
|
||||
self.init_mem_stats(*args)
|
||||
self._start_collect_memstats()
|
||||
|
||||
for p in self.module.parameters():
|
||||
|
@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
|
|||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
|
||||
self._memstats_collector.overall_mem_stats('cuda'))
|
||||
self._memstats_collector._memstats.overall_mem_stats('cuda'))
|
||||
|
||||
@torch.no_grad()
|
||||
def _post_backward_operations(self) -> None:
|
||||
|
|
|
@ -1,74 +1,77 @@
|
|||
import torch
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from functools import partial
|
||||
|
||||
|
||||
class MyTestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proj1 = nn.Linear(512, 512)
|
||||
self.weight = nn.Parameter(torch.randn(1024, 512))
|
||||
self.proj2 = nn.Linear(1024, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
x = F.linear(x, self.weight)
|
||||
x = self.proj2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def run_mem_collector_testing():
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
fraction = (50 * 1024**2) / cuda_capacity
|
||||
# limit max memory to 50MB
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
shard_strategy = BucketTensorShardStrategy()
|
||||
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
||||
model = MyTestModel()
|
||||
|
||||
model = ShardedModelV2(module=model,
|
||||
shard_strategy=shard_strategy,
|
||||
reduce_scatter_bucket_size_mb=1,
|
||||
tensor_placement_policy='auto')
|
||||
|
||||
data = torch.randn(2, 512, device=get_current_device())
|
||||
|
||||
output = model(data)
|
||||
loss = torch.mean(output)
|
||||
model.backward(loss)
|
||||
|
||||
cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
|
||||
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
||||
|
||||
cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
|
||||
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
||||
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_mem_collector_testing()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_mem_collector(world_size=2):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mem_collector()
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
|
||||
|
||||
class MyTestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proj1 = nn.Linear(512, 512)
|
||||
self.weight = nn.Parameter(torch.randn(1024, 512))
|
||||
self.proj2 = nn.Linear(1024, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
x = F.linear(x, self.weight)
|
||||
x = self.proj2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def run_mem_collector_testing():
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
fraction = (50 * 1024**2) / cuda_capacity
|
||||
# limit max memory to 50MB
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
shard_strategy = BucketTensorShardStrategy()
|
||||
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
||||
model = MyTestModel()
|
||||
|
||||
model = ShardedModelV2(module=model,
|
||||
shard_strategy=shard_strategy,
|
||||
reduce_scatter_bucket_size_mb=1,
|
||||
tensor_placement_policy='auto')
|
||||
|
||||
data = torch.randn(2, 512, device=get_current_device())
|
||||
|
||||
output = model(data)
|
||||
loss = torch.mean(output)
|
||||
model.backward(loss)
|
||||
|
||||
cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
|
||||
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
||||
|
||||
cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
|
||||
print('cuda_non_model_data_list ', cuda_non_model_data_list)
|
||||
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
||||
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_mem_collector_testing()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_mem_collector(world_size=2):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mem_collector()
|
||||
|
|
Loading…
Reference in New Issue