[Gemini] use MemStats to store the tracing data. Separate it from Collector. (#2084)

Jiarui Fang 2022-12-06 16:43:06 +08:00 committed by GitHub
parent 1f99205827
commit 33f4412102
5 changed files with 193 additions and 139 deletions
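The change is a pure refactor of where memory-tracing samples live: the raw per-device lists move out of MemStatsCollector (and its chunk-based subclass) into a new MemStats container, so collectors decide when to sample while MemStats owns how samples are stored and queried. In outline (a sketch of the before/after call shape, not literal lines from the diff):

    # before: the collector owned the statistic lists itself
    collector._model_data_cuda_list.append(cuda_mem)

    # after: storage is delegated to a MemStats instance
    collector._memstats.append_model_data('cuda', cuda_mem)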


@@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
         super().__init__()
         self._chunk_manager = chunk_manager
 
     # override
     def sample_model_data(self) -> None:
         """Sampling model data statistics.
         """
         if self._start_flag:
             cuda_mem = self._chunk_manager.total_mem['cuda']
             cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.append_model_data('cuda', cuda_mem)
+            self._memstats.append_model_data('cpu', cpu_mem)
 
     @property
     def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
+        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
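The margin computation itself is unchanged in spirit: device capacity minus the peak overall CUDA usage observed during collection, now read through MemStats.max_overall_cuda instead of max() over the collector's own list. With made-up numbers:

    # a sketch with hypothetical values, in bytes
    capacity = 16 * 1024**3        # colo_device_memory_capacity(get_current_device())
    peak_overall = 11 * 1024**3    # self._memstats.max_overall_cuda('cuda')
    cuda_margin_mem = capacity - peak_overall   # 5 GiB left over, e.g. for optimizer states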


@@ -0,0 +1,94 @@
+from typing import Any, Dict, List
+
+
+class MemStats(object):
+
+    def __init__(self) -> None:
+        """
+        Store the memory statistics (model data, overall data and the derived
+        non model data) used by Gemini and ZeroOptimizer.
+        """
+        # p -> list of non model data volumes visited in order.
+        self.param_non_model_data_map: Dict[Any, List[int]] = {}
+
+        self._model_data_cuda_list = []
+        self._model_data_cpu_list = []
+
+        self._overall_cuda_list = []
+        self._overall_cpu_list = []
+
+        self._non_model_data_cuda_list = []
+        self._non_model_data_cpu_list = []
+
+    def append_overall_data(self, device_type: str, val: float):
+        if device_type == 'cuda':
+            self._overall_cuda_list.append(val)
+        elif device_type == 'cpu':
+            self._overall_cpu_list.append(val)
+        else:
+            raise TypeError
+
+    def append_model_data(self, device_type: str, val: float):
+        if device_type == 'cuda':
+            self._model_data_cuda_list.append(val)
+        elif device_type == 'cpu':
+            self._model_data_cpu_list.append(val)
+        else:
+            raise TypeError
+
+    def append_non_model_data(self, device_type: str):
+        # derived from the latest model data and overall samples, so the two
+        # append_* methods above must have been called for this step first
+        if device_type == 'cuda':
+            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
+        elif device_type == 'cpu':
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+        else:
+            raise TypeError
+
+    def overall_mem_stats(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._overall_cuda_list
+        elif device_type == 'cpu':
+            return self._overall_cpu_list
+        else:
+            raise TypeError
+
+    def model_data_list(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._model_data_cuda_list
+        elif device_type == 'cpu':
+            return self._model_data_cpu_list
+        else:
+            raise TypeError
+
+    def non_model_data_list(self, device_type: str) -> List[int]:
+        if device_type == 'cuda':
+            return self._non_model_data_cuda_list
+        elif device_type == 'cpu':
+            return self._non_model_data_cpu_list
+        else:
+            raise TypeError
+
+    def max_non_model_data(self, device_type: str) -> float:
+        if device_type == 'cuda':
+            return max(self._non_model_data_cuda_list)
+        elif device_type == 'cpu':
+            return max(self._non_model_data_cpu_list)
+        else:
+            raise TypeError
+
+    def max_overall_cuda(self, device_type: str) -> float:
+        # note: despite the name, this serves 'cpu' as well
+        if device_type == 'cuda':
+            return max(self._overall_cuda_list)
+        elif device_type == 'cpu':
+            return max(self._overall_cpu_list)
+        else:
+            raise TypeError
+
+    def clear(self):
+        self._model_data_cuda_list = []
+        self._overall_cuda_list = []
+
+        self._model_data_cpu_list = []
+        self._overall_cpu_list = []
+
+        self._non_model_data_cpu_list = []
+        self._non_model_data_cuda_list = []
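MemStats thus records three parallel series per device and derives the third from the first two, which imposes a call-order contract: model data, then overall data, then the non model delta. A minimal sketch of one sampling step (the values are made up):

    stats = MemStats()
    stats.append_model_data('cuda', 1311744)    # parameter/gradient bytes on CUDA
    stats.append_overall_data('cuda', 5000000)  # bytes reported by the memory monitor
    stats.append_non_model_data('cuda')         # stores 5000000 - 1311744 = 3688256

    stats.non_model_data_list('cuda')   # -> [3688256]
    stats.max_overall_cuda('cuda')      # -> 5000000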


@@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.utils.memory import colo_device_memory_used
 
+from .memory_stats import MemStats
+
 
 class MemStatsCollector:
     """
@@ -22,43 +24,12 @@ class MemStatsCollector:
 
     def __init__(self) -> None:
         self._mem_monitor = SyncCudaMemoryMonitor()
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cuda_list = []
-        self._non_model_data_cpu_list = []
         self._sampling_time = []
 
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-    def overall_mem_stats(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._overall_cuda_list
-        elif device_type == 'cpu':
-            return self._overall_cpu_list
-        else:
-            raise TypeError
-
-    def model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._model_data_cpu_list
-        else:
-            raise TypeError
-
-    def non_model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._non_model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._non_model_data_cpu_list
-        else:
-            raise TypeError
+        self._memstats = MemStats()
 
     def next_period_non_model_data_usage(self, device_type: str) -> int:
         """Get max non model data memory usage of current sampling period
@@ -71,7 +42,7 @@
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
-        next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
+        next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
@@ -95,37 +66,29 @@ class MemStatsCollector:
         if self._start_flag:
             cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
             cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.append_model_data('cuda', cuda_mem)
+            self._memstats.append_model_data('cpu', cpu_mem)
 
     def sample_overall_data(self) -> None:
         """Sampling overall and non model data statistics.
         """
         if self._start_flag:
             # overall data recording is after model data recording
-            if len(self._model_data_cuda_list) == 0:
+            if len(self._memstats._model_data_cuda_list) == 0:
                 return
-            self._overall_cuda_list.append(self._mem_monitor.finish())
-            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
+            self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
+            self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
 
-            assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
+            assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
 
-            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
-            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+            self._memstats.append_non_model_data('cuda')
+            self._memstats.append_non_model_data('cpu')
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
 
     def clear(self) -> None:
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cpu_list = []
-        self._non_model_data_cuda_list = []
-
+        self._memstats.clear()
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
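After this change MemStatsCollector keeps only the sampling machinery (the CUDA monitor, timestamps, and the warmup/replay counters) and forwards all storage to MemStats. The intended cycle, assuming the start_collection/finish_collection methods referenced elsewhere in this commit flip _start_flag and set _step_total (a sketch, not code from the diff):

    collector = MemStatsCollector()

    collector.start_collection()           # warmup iteration: record stats
    for _ in range(num_ops):               # num_ops is hypothetical
        collector.sample_model_data()      # model data, before each operator
        # ... run the operator ...
        collector.sample_overall_data()    # overall data, after each operator
    collector.finish_collection()

    # later iterations replay the recorded non model data, one step at a time
    usage = collector.next_period_non_model_data_usage('cuda')  # advances _step_idx cyclically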


@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
-                 user_static_memstats: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
@@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        self.user_static_memstats = user_static_memstats
 
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
-            if self.user_static_memstats:
-                self._memstats_collector = StaticMemStatsCollector(self.module)
-            else:
-                self._memstats_collector = MemStatsCollector()
+            self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
             f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
             f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
             f.write('CUDA model data (GB)\n')
-            f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
+            f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
             f.write('\n')
             f.write('CUDA non model data (GB)\n')
-            f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
+            f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
             f.write('CPU non model data (GB)\n')
-            f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
+            f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
             f.write('\n')
 
     def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
         if self._memstats_collector:
-            if self.user_static_memstats:
-                self.init_mem_stats(*args)
             self._start_collect_memstats()
 
         for p in self.module.parameters():
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS (optimizer states).
             self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
-                self._memstats_collector.overall_mem_stats('cuda'))
+                self._memstats_collector._memstats.overall_mem_stats('cuda'))
 
     @torch.no_grad()
     def _post_backward_operations(self) -> None:
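One design consequence visible in these call sites: consumers now reach through the collector's private _memstats attribute (e.g. _memstats_collector._memstats.overall_mem_stats('cuda')). A thin delegating accessor on the collector, hypothetical and not part of this commit, would keep the old public surface intact:

    class MemStatsCollector:
        ...
        def overall_mem_stats(self, device_type: str) -> List[int]:
            # hypothetical convenience wrapper: delegate to the MemStats store
            return self._memstats.overall_mem_stats(device_type)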


@@ -1,17 +1,19 @@
-import torch
-import colossalai
+from functools import partial
+
 import pytest
+import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.functional as F
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.shard_utils import BucketTensorShardStrategy
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
-from colossalai.zero.sharded_model import ShardedModelV2
 
 
 class MyTestModel(torch.nn.Module):
@@ -50,10 +52,11 @@ def run_mem_collector_testing():
     loss = torch.mean(output)
     model.backward(loss)
 
-    cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
+    cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
     assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
 
-    cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
+    cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
+    print('cuda_non_model_data_list ', cuda_non_model_data_list)
     assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
     assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]