[zero] refactor memstats_collector (#746)

pull/756/head
HELSON 2022-04-14 12:01:12 +08:00 committed by GitHub
parent b8899e0905
commit 84c6700b2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 179 additions and 111 deletions

View File

@ -1,3 +1,4 @@
from .utils import register_ophooks_recursively, BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]

View File

@ -1,6 +1,6 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor
import torch
import time
from typing import List
@ -19,7 +19,7 @@ class MemStatsCollector:
"""
def __init__(self) -> None:
self._mem_monitor = AsyncMemoryMonitor()
self._mem_monitor = SyncCudaMemoryMonitor()
self._model_data_cuda_list = []
self._overall_cuda_list = []
@ -31,9 +31,10 @@ class MemStatsCollector:
self._sampling_time = []
self._start_flag = False
self._period_idx = 0
self._step_idx = 0
self._step_total = 0
def overall_mem_stats(self, device_type: str):
def overall_mem_stats(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._overall_cuda_list
elif device_type == 'cpu':
@ -41,47 +42,23 @@ class MemStatsCollector:
else:
raise TypeError
def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
elif unit == 'B':
scale = 1
else:
raise TypeError
def model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return [elem / scale for elem in self._model_data_cuda_list]
return self._model_data_cuda_list
elif device_type == 'cpu':
return [elem / scale for elem in self._model_data_cpu_list]
else:
raise TypeError
def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
"""Non model data stats
"""
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
elif unit == 'B':
scale = 1
return self._model_data_cpu_list
else:
raise TypeError
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return [elem / scale for elem in self._non_model_data_cuda_list]
return self._non_model_data_cuda_list
elif device_type == 'cpu':
return [elem / scale for elem in self._non_model_data_cpu_list]
return self._non_model_data_cpu_list
else:
raise TypeError
def max_non_model_data(self, device_type: str) -> int:
def next_period_non_model_data_usage(self, device_type: str) -> int:
"""Get max non model data memory usage of current sampling period
Args:
@ -91,12 +68,10 @@ class MemStatsCollector:
int: max non model data memory usage of current sampling period
"""
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
self._period_idx = next_period_idx
return max(current_non_model_data, next_non_model_data)
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data
@property
def sampling_time(self):
@ -107,9 +82,37 @@ class MemStatsCollector:
self._mem_monitor.start()
def finish_collection(self):
self.sample_overall_data()
self._step_total = len(self._sampling_time)
self._start_flag = False
self._mem_monitor.finish()
def sample_model_data(self) -> None:
"""Sampling model data statistics.
"""
if self._start_flag:
cuda_mem, cpu_mem = GLOBAL_MODEL_DATA_TRACER.both_mem_usage
self._model_data_cuda_list.append(cuda_mem)
self._model_data_cpu_list.append(cpu_mem)
def sample_overall_data(self) -> None:
"""Sampling non model data statistics.
"""
if self._start_flag:
# overall data recording is after model data recording
if len(self._model_data_cuda_list) == 0:
return
self._overall_cuda_list.append(self._mem_monitor.finish())
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
self._sampling_time.append(time.time())
self._mem_monitor.start()
def sample_memstats(self) -> None:
"""
Sampling memory statistics.
@ -119,7 +122,7 @@ class MemStatsCollector:
if self._start_flag:
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda_list.append(self._mem_monitor.finish())
self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
@ -136,4 +139,5 @@ class MemStatsCollector:
self._overall_cpu_list = []
self._start_flag = False
self._period_idx = 0
self._step_idx = 0
self._step_total = 0

View File

@ -101,5 +101,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
cuda_usage, _ = self._get_mem_usage()
return cuda_usage
@property
def both_mem_usage(self):
return self._get_mem_usage()
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()

View File

@ -109,6 +109,5 @@ class ShardedParamV2(object):
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
_update_mem_use(self.param.grad)
address_set.add(self.param.grad.data_ptr())
return cuda_mem_use, cpu_mem_use

View File

@ -13,7 +13,7 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
cuda_use, cpu_use = 0, 0
mem_use = t.numel() * t.element_size()
mem_use = t.storage().size() * t.element_size()
if t.device.type == 'cuda':
cuda_use += mem_use
elif t.device.type == 'cpu':

View File

@ -38,10 +38,6 @@ class StatefulTensorMgr(object):
def adjust_layout(self) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
Args:
mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
It contains non-model footprint of a DNN model.
"""
# find stateful tensor in state COMPUTE
cuda_demand = 0

View File

@ -61,7 +61,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.max_non_model_data('cuda')
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
if avail_cuda_model_data < cuda_demand:
@ -71,7 +71,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
freed_cuda_model_data = 0
to_free_tensor_list = hold_cuda_tensor_list
if not warmup:
next_compute_idx: Dict[StatefulTensor, int] = {t: len(compute_list) for t in hold_cuda_tensor_list}
next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensor_list}
for i in range(len(compute_list) - 1, compute_idx, -1):
if compute_list[i] in next_compute_idx:
next_compute_idx[compute_list[i]] = i

View File

@ -36,17 +36,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector = memstarts_collector
self._stateful_tensor_mgr = stateful_tensor_mgr
def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
if self._stateful_tensor_mgr:
self._stateful_tensor_mgr.adjust_layout()
else:
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
def gather_parameters(self, module: torch.nn.Module):
# gather sharded parameters
if module.param_is_sharded:
tensor_list = []
@ -55,10 +45,33 @@ class ZeroHook(BaseOpHook):
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# record memory statistics
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
def shard_parameters(self, module: torch.nn.Module):
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
def adjust_module_data(self, module: torch.nn.Module):
# record overall data statistics
if self._memstarts_collector:
self._memstarts_collector.sample_overall_data()
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
# adjust stateful tensor to get enough CUDA memory
self._stateful_tensor_mgr.adjust_layout()
# record model data statistics
if self._memstarts_collector:
self._memstarts_collector.sample_model_data()
def pre_fwd_exec(self, module: torch.nn.Module, *args):
self.adjust_module_data(module)
self.gather_parameters(module)
for param in module.parameters(recurse=False):
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
@ -69,41 +82,15 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
self.shard_parameters(module)
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.set_data_none()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
if self._stateful_tensor_mgr:
self._stateful_tensor_mgr.adjust_layout()
else:
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
# gather sharded parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# record memory statistics
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
self.adjust_module_data(module)
self.gather_parameters(module)
for param in module.parameters(recurse=False):
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
@ -114,13 +101,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
self.shard_parameters(module)
# remove torch payload
for param in module.parameters(recurse=False):

View File

@ -0,0 +1,74 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from functools import partial
class TestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.proj1 = nn.Linear(512, 512)
self.weight = nn.Parameter(torch.randn(1024, 512))
self.proj2 = nn.Linear(1024, 512)
def forward(self, x):
x = self.proj1(x)
x = F.linear(x, self.weight)
x = self.proj2(x)
return x
def run_mem_collector_testing():
cuda_capacity = colo_device_memory_capacity(get_current_device())
fraction = (50 * 1024**2) / cuda_capacity
# limit max memory to 50MB
colo_set_process_memory_fraction(fraction)
shard_strategy = BucketTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
model = TestModel()
model = ShardedModelV2(module=model,
shard_strategy=shard_strategy,
reduce_scatter_bucket_size_mb=1,
tensor_placement_policy='auto')
data = torch.randn(2, 512, device=get_current_device())
output = model(data)
loss = torch.mean(output)
model.backward(loss)
cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_mem_collector_testing()
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_mem_collector(world_size=2):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_mem_collector()

View File

@ -48,30 +48,39 @@ def run_stm():
# warmup
# use naive eviction strategy
apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.finish_collection()
stateful_tensor_mgr.reset()
# warmup done
# use OPT-like eviction strategy
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.sample_overall_data()
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.sample_model_data()
mem_collector.finish_collection()
def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],