mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] update API of the chunkmemstatscollector. (#2129)
parent
2938edf446
commit
c89c66a858
|
@ -55,7 +55,7 @@ class GeminiManager:
|
|||
|
||||
get the memory statistics during training.
|
||||
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
|
||||
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
||||
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
||||
"""
|
||||
if self._premade_memstats_:
|
||||
return self._memstats
|
||||
|
|
|
@ -11,18 +11,25 @@ from .memstats_collector import MemStatsCollector
|
|||
class ChunkMemStatsCollector(MemStatsCollector):
|
||||
|
||||
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||
"""
|
||||
|
||||
Memory Statistic Collector for Chunks.
|
||||
|
||||
Args:
|
||||
chunk_manager (ChunkManager): the chunk manager.
|
||||
memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
|
||||
"""
|
||||
super().__init__(memstats)
|
||||
self._chunk_manager = chunk_manager
|
||||
|
||||
# override
|
||||
def record_model_data_volume(self) -> None:
|
||||
"""Sampling model data statistics.
|
||||
"""
|
||||
record model data volumn on cuda and cpu.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
||||
self._memstats.append_model_data('cuda', cuda_mem)
|
||||
self._memstats.append_model_data('cpu', cpu_mem)
|
||||
self._memstats.record_max_cuda_model_data(cuda_mem)
|
||||
|
||||
@property
|
||||
def cuda_margin_mem(self) -> float:
|
||||
|
|
|
@ -22,6 +22,7 @@ class MemStats(object):
|
|||
self._preop_step = 0
|
||||
|
||||
self._prev_overall_cuda = -1
|
||||
self._max_overall_cuda = 0
|
||||
self._prev_md_cuda = -1
|
||||
|
||||
# old version
|
||||
|
@ -46,6 +47,11 @@ class MemStats(object):
|
|||
|
||||
def record_max_cuda_overall_data(self, val):
|
||||
self._prev_overall_cuda = val
|
||||
self._max_overall_cuda = max(self._max_overall_cuda, val)
|
||||
|
||||
@property
|
||||
def max_overall_cuda(self):
|
||||
return self._max_overall_cuda
|
||||
|
||||
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
|
||||
"""
|
||||
|
@ -85,67 +91,6 @@ class MemStats(object):
|
|||
else:
|
||||
return self._param_runtime_order
|
||||
|
||||
## APIs to be depracated
|
||||
def append_overall_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._overall_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
self._overall_cpu_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def append_model_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._model_data_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
self._model_data_cpu_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def last_model_data(self, device_type: str):
|
||||
if len(self._model_data_cuda_list) == 0:
|
||||
return None
|
||||
if device_type == 'cuda':
|
||||
return self._model_data_cuda_list[-1]
|
||||
elif device_type == 'cpu':
|
||||
return self._model_data_cpu_list[-1]
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def append_non_model_data(self, device_type: str, val=None):
|
||||
if device_type == 'cuda':
|
||||
if val is None:
|
||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
||||
return
|
||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||
else:
|
||||
self._non_model_data_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
if val is None:
|
||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
||||
return
|
||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||
else:
|
||||
self._non_model_data_cuda_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def overall_mem_stats(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._overall_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._overall_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._non_model_data_cuda_list
|
||||
|
|
|
@ -59,6 +59,7 @@ class MemStatsCollector:
|
|||
return [t - self._sampling_time[0] for t in self._sampling_time]
|
||||
|
||||
def start_collection(self):
|
||||
print('start collection')
|
||||
self._start_flag = True
|
||||
self._mem_monitor.start()
|
||||
|
||||
|
@ -68,31 +69,24 @@ class MemStatsCollector:
|
|||
self._step_total = len(self._memstats.non_model_data_list('cuda'))
|
||||
self._start_flag = False
|
||||
self._mem_monitor.finish()
|
||||
print(f'finish_collection {self._step_total}')
|
||||
|
||||
# deprecated
|
||||
def record_model_data_volume(self) -> None:
|
||||
"""Sampling model data statistics.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||
self._memstats.append_model_data('cuda', cuda_mem)
|
||||
self._memstats.append_model_data('cpu', cpu_mem)
|
||||
raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume")
|
||||
|
||||
def sample_overall_data(self) -> None:
|
||||
"""Sampling non model data statistics.
|
||||
"""
|
||||
Sampling overall and non model data cuda memory statistics.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
# overall data recording is after model data recording
|
||||
if len(self._memstats._model_data_cuda_list) == 0:
|
||||
return
|
||||
cuda_overall = self._mem_monitor.finish()
|
||||
self._memstats.record_max_cuda_overall_data(cuda_overall)
|
||||
self._memstats.calc_max_cuda_non_model_data()
|
||||
|
||||
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
|
||||
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
|
||||
|
||||
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
|
||||
|
||||
self._memstats.append_non_model_data('cuda')
|
||||
self._memstats.append_non_model_data('cpu')
|
||||
self._mem_monitor.start()
|
||||
|
||||
if self._start_flag:
|
||||
|
|
|
@ -206,7 +206,6 @@ class ShardedModelV2(nn.Module):
|
|||
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
||||
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
||||
f.write('CUDA model data (GB)\n')
|
||||
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
|
||||
f.write('\n')
|
||||
f.write('CUDA non model data (GB)\n')
|
||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
||||
|
@ -256,8 +255,8 @@ class ShardedModelV2(nn.Module):
|
|||
# the way to calculate margin space is based on the assumption that
|
||||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
|
||||
self._memstats_collector._memstats.overall_mem_stats('cuda'))
|
||||
self._cuda_margin_space = colo_device_memory_capacity(
|
||||
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
|
||||
|
||||
@torch.no_grad()
|
||||
def _post_backward_operations(self) -> None:
|
||||
|
|
|
@ -32,6 +32,8 @@ class GeminiZeROHook(ColoParamOpHook):
|
|||
self._gemini_manager.adjust_layout(chunks)
|
||||
for chunk in chunks:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
|
||||
# record cuda model data of the current OP
|
||||
self._gemini_manager.record_model_data_volume()
|
||||
|
||||
def post_op(self, params):
|
||||
|
|
|
@ -57,11 +57,10 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
|||
|
||||
if model_name == 'repeated_computed_layers':
|
||||
for idx, p in enumerate(model.parameters()):
|
||||
step_list = memstats.param_used_timestep(p)
|
||||
step_list = memstats.param_used_step(p)
|
||||
if idx < 4:
|
||||
assert len(step_list) == 4
|
||||
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
|
|
|
@ -1,77 +0,0 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
|
||||
|
||||
class MyTestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proj1 = nn.Linear(512, 512)
|
||||
self.weight = nn.Parameter(torch.randn(1024, 512))
|
||||
self.proj2 = nn.Linear(1024, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
x = F.linear(x, self.weight)
|
||||
x = self.proj2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def run_mem_collector_testing():
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
fraction = (50 * 1024**2) / cuda_capacity
|
||||
# limit max memory to 50MB
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
shard_strategy = BucketTensorShardStrategy()
|
||||
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
||||
model = MyTestModel()
|
||||
|
||||
model = ShardedModelV2(module=model,
|
||||
shard_strategy=shard_strategy,
|
||||
reduce_scatter_bucket_size_mb=1,
|
||||
tensor_placement_policy='auto')
|
||||
|
||||
data = torch.randn(2, 512, device=get_current_device())
|
||||
|
||||
output = model(data)
|
||||
loss = torch.mean(output)
|
||||
model.backward(loss)
|
||||
|
||||
cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
|
||||
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
|
||||
|
||||
cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
|
||||
print('cuda_non_model_data_list ', cuda_non_model_data_list)
|
||||
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
|
||||
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_mem_collector_testing()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_mem_collector(world_size=2):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mem_collector()
|
Loading…
Reference in New Issue