[Gemini] update API of the chunkmemstatscollector. (#2129)

pull/2130/head^2
Jiarui Fang 2022-12-14 00:47:06 +08:00 committed by GitHub
parent 2938edf446
commit c89c66a858
8 changed files with 32 additions and 163 deletions


@@ -55,7 +55,7 @@ class GeminiManager:
get the memory statistics during training.
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
Note, for the latter, you cannot access the memstats before the warmup iteration finishes.
"""
if self._premade_memstats_:
return self._memstats
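As a hedged illustration of the two collection paths the docstring above describes (the GeminiManager constructor arguments shown are assumptions for this sketch, not verified against this commit):

    # Path 1: stats premade by a runtime memory tracer (RMT) are handed in,
    # so memstats() can be served at any time.
    manager_with_rmt = GeminiManager(placement_policy, chunk_manager, memstats=rmt_memstats)  # assumed signature
    # Path 2: no premade stats; GeminiManager collects them itself, so
    # memstats() is only meaningful after the warmup iteration finishes.
    manager_collecting = GeminiManager(placement_policy, chunk_manager)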


@@ -11,18 +11,25 @@ from .memstats_collector import MemStatsCollector
class ChunkMemStatsCollector(MemStatsCollector):
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
"""
Memory Statistics Collector for Chunks.
Args:
chunk_manager (ChunkManager): the chunk manager.
memstats (Optional[MemStats], optional): memory statistics collected by the runtime memory tracer (RMT). Defaults to None.
"""
super().__init__(memstats)
self._chunk_manager = chunk_manager
# override
def record_model_data_volume(self) -> None:
"""Sampling model data statistics.
"""
record model data volumn on cuda and cpu.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = self._chunk_manager.total_mem['cuda']
cpu_mem = self._chunk_manager.total_mem['cpu']
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
self._memstats.record_max_cuda_model_data(cuda_mem)
@property
def cuda_margin_mem(self) -> float:
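A hedged usage sketch of the updated collector API, assembled from the methods visible in this diff; the warmup loop, the variable names, and the finish_collection name (inferred from the log line further down) are illustrative assumptions.

    collector = ChunkMemStatsCollector(chunk_manager)    # no premade MemStats, so it collects locally
    collector.start_collection()
    for _ in range(warmup_steps):                        # hypothetical warmup loop
        collector.record_model_data_volume()             # reads chunk_manager.total_mem['cuda'] / ['cpu']
        collector.sample_overall_data()                  # records peak overall / non-model CUDA data
    collector.finish_collection()
    margin = collector.cuda_margin_mem                   # CUDA memory left over for optimizer states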


@@ -22,6 +22,7 @@ class MemStats(object):
self._preop_step = 0
self._prev_overall_cuda = -1
self._max_overall_cuda = 0
self._prev_md_cuda = -1
# old version
@@ -46,6 +47,11 @@ class MemStats(object):
def record_max_cuda_overall_data(self, val):
self._prev_overall_cuda = val
self._max_overall_cuda = max(self._max_overall_cuda, val)
@property
def max_overall_cuda(self):
return self._max_overall_cuda
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
"""
@@ -85,67 +91,6 @@ class MemStats(object):
else:
return self._param_runtime_order
## APIs to be deprecated
def append_overall_data(self, device_type: str, val: float):
if device_type == 'cuda':
self._overall_cuda_list.append(val)
elif device_type == 'cpu':
self._overall_cpu_list.append(val)
else:
raise TypeError
def append_model_data(self, device_type: str, val: float):
if device_type == 'cuda':
self._model_data_cuda_list.append(val)
elif device_type == 'cpu':
self._model_data_cpu_list.append(val)
else:
raise TypeError
def last_model_data(self, device_type: str):
if len(self._model_data_cuda_list) == 0:
return None
if device_type == 'cuda':
return self._model_data_cuda_list[-1]
elif device_type == 'cpu':
return self._model_data_cpu_list[-1]
else:
raise TypeError
def append_non_model_data(self, device_type: str, val=None):
if device_type == 'cuda':
if val is None:
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
return
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
else:
self._non_model_data_cuda_list.append(val)
elif device_type == 'cpu':
if val is None:
if len(self._overall_cpu_list) == 0 or len(self._model_data_cpu_list) == 0:
return
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
else:
self._non_model_data_cpu_list.append(val)
else:
raise TypeError
def overall_mem_stats(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._overall_cuda_list
elif device_type == 'cpu':
return self._overall_cpu_list
else:
raise TypeError
def model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._model_data_cuda_list
elif device_type == 'cpu':
return self._model_data_cpu_list
else:
raise TypeError
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._non_model_data_cuda_list
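To summarize what replaces the deprecated list-based APIs above, a hedged sketch of the peak-tracking flow; MemStats is constructed with default arguments here, and the comment on calc_max_cuda_non_model_data is an assumption since its body is not shown in this diff.

    stats = MemStats()
    stats.record_max_cuda_model_data(cuda_model_bytes)       # e.g. ChunkManager.total_mem['cuda']
    stats.record_max_cuda_overall_data(cuda_overall_bytes)   # keeps the running maximum in _max_overall_cuda
    stats.calc_max_cuda_non_model_data()                     # presumably overall minus model data at that step
    peak = stats.max_overall_cuda                            # read back through the new property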


@@ -59,6 +59,7 @@ class MemStatsCollector:
return [t - self._sampling_time[0] for t in self._sampling_time]
def start_collection(self):
print('start collection')
self._start_flag = True
self._mem_monitor.start()
@@ -68,31 +69,24 @@
self._step_total = len(self._memstats.non_model_data_list('cuda'))
self._start_flag = False
self._mem_monitor.finish()
print(f'finish_collection {self._step_total}')
# deprecated
def record_model_data_volume(self) -> None:
"""Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume")
def sample_overall_data(self) -> None:
"""Sampling non model data statistics.
"""
Sampling overall and non model data cuda memory statistics.
"""
if self._start_flag and not self.use_outside_memstats:
# overall data recording is after model data recording
if len(self._memstats._model_data_cuda_list) == 0:
return
cuda_overall = self._mem_monitor.finish()
self._memstats.record_max_cuda_overall_data(cuda_overall)
self._memstats.calc_max_cuda_non_model_data()
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
self._memstats.append_non_model_data('cuda')
self._memstats.append_non_model_data('cpu')
self._mem_monitor.start()
if self._start_flag:
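For readers following the reworked sample_overall_data, a hedged sketch of the monitor bracketing it relies on; that finish() returns the peak CUDA usage since the matching start() is an assumption based on how the value is consumed above.

    mem_monitor.start()                                    # open a sampling window before the op runs
    # ... one operator executes ...
    cuda_overall = mem_monitor.finish()                    # peak CUDA bytes observed in the window
    memstats.record_max_cuda_overall_data(cuda_overall)    # keep the running maximum
    memstats.calc_max_cuda_non_model_data()                # update the non-model-data peak
    mem_monitor.start()                                    # reopen the window for the next op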


@@ -206,7 +206,6 @@ class ShardedModelV2(nn.Module):
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n')
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
f.write('\n')
f.write('CUDA non model data (GB)\n')
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
@@ -256,8 +255,8 @@ class ShardedModelV2(nn.Module):
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS (optimizer states).
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
self._memstats_collector._memstats.overall_mem_stats('cuda'))
self._cuda_margin_space = colo_device_memory_capacity(
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
@torch.no_grad()
def _post_backward_operations(self) -> None:
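To make the replaced max() call concrete, a tiny worked example of the margin formula with hypothetical numbers; the figures are illustrative, not measured.

    capacity = 16 * 1024**3              # device reports 16 GiB of CUDA memory
    peak = 11 * 1024**3                  # max_overall_cuda observed during warmup
    cuda_margin_space = capacity - peak  # ~5 GiB left to park optimizer states (OS)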


@@ -32,6 +32,8 @@ class GeminiZeROHook(ColoParamOpHook):
self._gemini_manager.adjust_layout(chunks)
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
# record cuda model data of the current OP
self._gemini_manager.record_model_data_volume()
def post_op(self, params):
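A hedged summary of where the new call sits in the hook lifecycle, written as comments; the trigger points of ColoParamOpHook are paraphrased from the method names here, not quoted from this commit.

    # pre_op(params):  adjust_layout(chunks)                      -> decide which chunks must be resident on CUDA
    #                  access_chunk(chunk) for each chunk         -> materialize them on the device
    #                  gemini_manager.record_model_data_volume()  -> new: log CUDA model data for this op
    # ... the operator runs ...
    # post_op(params): chunks are released / offloaded according to the placement policy (assumed)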


@@ -57,11 +57,10 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
if model_name == 'repeated_computed_layers':
for idx, p in enumerate(model.parameters()):
step_list = memstats.param_used_step(p)
if idx < 4:
assert len(step_list) == 4
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000


@@ -1,77 +0,0 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
class MyTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.proj1 = nn.Linear(512, 512)
self.weight = nn.Parameter(torch.randn(1024, 512))
self.proj2 = nn.Linear(1024, 512)
def forward(self, x):
x = self.proj1(x)
x = F.linear(x, self.weight)
x = self.proj2(x)
return x
def run_mem_collector_testing():
cuda_capacity = colo_device_memory_capacity(get_current_device())
fraction = (50 * 1024**2) / cuda_capacity
# limit max memory to 50MB
colo_set_process_memory_fraction(fraction)
shard_strategy = BucketTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
model = MyTestModel()
model = ShardedModelV2(module=model,
shard_strategy=shard_strategy,
reduce_scatter_bucket_size_mb=1,
tensor_placement_policy='auto')
data = torch.randn(2, 512, device=get_current_device())
output = model(data)
loss = torch.mean(output)
model.backward(loss)
cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
print('cuda_non_model_data_list ', cuda_non_model_data_list)
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_mem_collector_testing()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_mem_collector()