[zero] refactor memstats collector (#706)

* refactor memstats collector

* fix disposable

* polish code
ver217 3 years ago committed by GitHub
parent 3fc8a204dc
commit ab8c6b4a0e

@@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param)
sync_model_param, disposable)
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory_utils.memory_monitor import report_memory_usage
@@ -19,5 +19,5 @@ __all__ = [
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
'ensure_path_exists'
'ensure_path_exists', 'disposable'
]

@@ -4,8 +4,8 @@ import os
import random
import socket
from pathlib import Path
from typing import List, Union
from typing import Callable, List, Union
import functools
import torch
from torch._six import inf
from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
class model_branch_context(object):
def __enter__(self):
self.env_status = env.save()
@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
colossal_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm
False # no per-parameter norm
)
return norm
@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
yield
finally:
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
def disposable(func: Callable) -> Callable:
executed = False
@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal executed
if not executed:
executed = True
return func(*args, **kwargs)
return wrapper

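The new `disposable` helper above turns any callable into a run-once function: the first call executes the wrapped function and returns its result, and every later call is silently skipped (returning None). A minimal behavioral sketch, assuming ColossalAI with this patch is installed; the `announce` function is purely illustrative:

from colossalai.utils import disposable

@disposable
def announce(msg):
    print(msg)
    return len(msg)

print(announce('first call'))    # prints 'first call' and returns 10
print(announce('second call'))   # body is skipped: nothing is printed, returns None
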
@@ -1,36 +1,11 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
import torch
import time
from typing import List
class SamplingCounter:
def __init__(self) -> None:
self._samplint_cnt = 0
self._max_sampling_cnt = None
def advance(self):
self._samplint_cnt += 1
def next(self):
assert self._max_sampling_cnt is not None
return (self._samplint_cnt + 1) % self._max_sampling_cnt
def current(self):
return self._samplint_cnt
def max(self):
return self._max_sampling_cnt
def reset(self):
self._max_sampling_cnt = self._samplint_cnt
self._samplint_cnt = 0
class MemStatsCollector:
"""
A Memory statistic collector.
@@ -44,7 +19,6 @@ class MemStatsCollector:
"""
def __init__(self) -> None:
self._sampling_cnter = SamplingCounter()
self._mem_monitor = AsyncMemoryMonitor()
self._model_data_cuda_list = []
self._overall_cuda_list = []
@@ -57,6 +31,7 @@
self._sampling_time = []
self._start_flag = False
self._period_idx = 0
def overall_mem_stats(self, device_type: str):
if device_type == 'cuda':
@@ -106,15 +81,22 @@ class MemStatsCollector:
else:
raise TypeError
def current_non_model_data(self, device_type: str) -> int:
"""get the non model data of the current sampling moment
"""
return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
def max_non_model_data(self, device_type: str) -> int:
"""Get max non model data memory usage of current sampling period
def next_non_model_data(self, device_type: str):
"""get the non model data of the next sampling moment
Args:
device_type (str): device type, can be 'cpu' or 'cuda'.
Returns:
int: max non model data memory usage of current sampling period
"""
return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
self._period_idx = next_period_idx
return max(current_non_model_data, next_non_model_data)
@property
def sampling_time(self):
@@ -126,6 +108,7 @@ class MemStatsCollector:
def finish_collection(self):
self._start_flag = False
self._mem_monitor.finish()
def sample_memstats(self) -> None:
"""
@@ -134,8 +117,6 @@ class MemStatsCollector:
Advance the sampling cnter.
"""
if self._start_flag:
sampling_cnt = self._sampling_cnter.current()
assert sampling_cnt == len(self._overall_cuda_list)
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda_list.append(self._mem_monitor.finish())
self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
@@ -146,13 +127,6 @@ class MemStatsCollector:
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
self._sampling_time.append(time.time())
self._mem_monitor.start()
# TODO(ver217): refactor sampler
# print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
self._sampling_cnter.advance()
def reset_sampling_cnter(self) -> None:
self._sampling_cnter.reset()
self._mem_monitor.finish()
def clear(self) -> None:
self._model_data_cuda_list = []
@@ -162,5 +136,4 @@ class MemStatsCollector:
self._overall_cpu_list = []
self._start_flag = False
self._sampling_cnter.reset()
self._mem_monitor.finish()
self._period_idx = 0

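With `SamplingCounter` gone, the collector keeps a plain `_period_idx` that `max_non_model_data` advances internally, so callers only start collection, sample, and finish; there is no sampling counter to reset between iterations. A rough usage sketch of the refactored API, mirroring the deleted test further below, assuming a CUDA device and ColossalAI with this change installed (tensor sizes are arbitrary):

import torch
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()
collector.start_collection()
a = torch.randn(1024, 1024).cuda()    # allocation in sampling period 0
collector.sample_memstats()
b = torch.randn(1024, 1024).cuda()    # allocation in sampling period 1
collector.sample_memstats()
collector.finish_collection()

# Queries walk the recorded periods via _period_idx; no reset_sampling_cnter() call is needed.
print(collector.overall_mem_stats('cuda'))
print(collector.max_non_model_data('cuda'))
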
@@ -1,16 +0,0 @@
from async_memtracer import AsyncMemoryMonitor
import torch
if __name__ == '__main__':
async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda()
async_mem_monitor.start()
output = OP1(input)
async_mem_monitor.finish()
async_mem_monitor.start()
output = OP2(output)
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')

@@ -1,37 +0,0 @@
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
import torch
def test_mem_collector():
collector = MemStatsCollector()
collector.start_collection()
a = torch.randn(10).cuda()
# sampling at time 0
collector.sample_memstats()
m_a = torch.randn(10).cuda()
b = torch.randn(10).cuda()
# sampling at time 1
collector.sample_memstats()
a = b
# sampling at time 2
collector.sample_memstats()
collector.finish_collection()
collector.reset_sampling_cnter()
# do nothing after collection, just advance sampling cnter
collector.sample_memstats()
collector.sample_memstats()
print(collector.overall_mem_stats('cuda'))
if __name__ == '__main__':
test_mem_collector()

@@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
self._mem_stats_collector.next_non_model_data('cuda'))
max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

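After warmup, `StatefulTensorMgr` now asks the collector for a single `max_non_model_data('cuda')` value instead of taking the max of the current and next samples itself. A worked example of the budget arithmetic in the hunk above, using made-up figures (all values hypothetical, in bytes):

cuda_capacity = 16 * 1024**3                        # e.g. a 16 GiB device
max_cuda_non_model_data_per_period = 3 * 1024**3    # as returned by max_non_model_data('cuda')
used_cuda_model_data = 5 * 1024**3                  # model data already resident on the GPU

total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period   # 13 GiB budget for model data
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data          # 8 GiB may still be moved onto the GPU
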
@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.engine.gradient_handler.utils import bucket_allreduce
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils import get_current_device, disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
for param in submodule.parameters(recurse=False):
if hasattr(param, 'colo_attr'):
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
else:
self._memstats_collector = None
self._stateful_tensor_mgr = None
self._iter_cnter = 0
# Register hooks
self._ophook_list = [
@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
f.write('\n')
def _pre_forward_operations(self):
if self._iter_cnter == 0 and self._memstats_collector:
# the operation will affect the memory tracer behavior in ZeroHook
self._memstats_collector.start_collection()
# the operation will affect the memory tracer behavior in ZeroHook
if self._memstats_collector:
self._start_collect_memstats()
for p in self.module.parameters():
if hasattr(p, 'colo_attr'):
@@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
ophook.post_iter()
def _update_memstats(self):
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
if self._memstats_collector:
self._memstats_collector.reset_sampling_cnter()
self._finish_collect_memstats()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
self._memstats_collector.overall_mem_stats('cuda'))
self._iter_cnter += 1
@torch.no_grad()
def _post_backward_operations(self) -> None:

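ShardedModelV2 no longer counts iterations with `_iter_cnter`; wrapping `start_collection` and `finish_collection` with `disposable` at construction time guarantees the mem-stats pass runs exactly once, during the first iteration, and the margin-space comment above then uses the collected peak. A sketch of the same pattern in isolation, assuming a CUDA device; the loop, step count, and the use of torch's reported device capacity in place of colo_cuda_memory_capacity() are illustrative only:

import torch
from colossalai.utils import disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()
start_once = disposable(collector.start_collection)
finish_once = disposable(collector.finish_collection)

for step in range(3):
    start_once()                  # actually starts collection only at step 0
    collector.sample_memstats()   # records a period only while collection is active
    finish_once()                 # actually finishes collection only at step 0

# Margin space as in _update_memstats above: capacity minus the peak fwd/bwd usage
# observed during the first (collected) iteration.
cuda_margin_space = (torch.cuda.get_device_properties(0).total_memory -
                     max(collector.overall_mem_stats('cuda')))
print(cuda_margin_space)
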
@@ -55,7 +55,6 @@ def run_stm():
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
mem_collector.sample_memstats()
mem_collector.finish_collection()
mem_collector.reset_sampling_cnter()
stateful_tensor_mgr.reset()
# warmup done
