[refactor] memory utils (#577)

pull/621/head
Jiarui Fang 2022-04-01 09:22:33 +08:00 committed by GitHub
parent 104cbbb313
commit e956d93ac2
15 changed files with 261 additions and 202 deletions

View File

@@ -29,6 +29,7 @@ class MoeGradientHandler(BaseGradientHandler):
if global_data > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism
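For reference, the gradient reduction the comments above describe (averaging the gradients of replicated parameters across the data-parallel group) follows roughly the pattern below. This is a minimal sketch with plain torch.distributed, not the handler's actual implementation, and the helper name is hypothetical.

import torch.distributed as dist

def allreduce_replicated_grads(params, group=None):
    # Hypothetical helper: average gradients of parameters that are replicated
    # across the data-parallel group (the epsize == 1 case described above).
    # Assumes torch.distributed has already been initialized.
    world_size = dist.get_world_size(group)
    for p in params:
        if p.grad is not None:
            p.grad.data.div_(world_size)
            dist.all_reduce(p.grad.data, group=group)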

View File

@@ -10,8 +10,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_utils.utils import \
colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
@OPHOOKS.register_module

View File

@@ -4,8 +4,8 @@ import pickle
import torch
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
class AsyncMemoryMonitor:
@@ -82,7 +82,7 @@ class AsyncMemoryMonitor:
while self.keep_measuring:
max_usage = max(
max_usage,
colo_cuda_memory_used(),
colo_device_memory_used(get_current_device()),
)
sleep(self.interval)
return max_usage
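The loop above polls the device at a fixed interval and keeps the maximum reading. A self-contained sketch of that polling pattern, using torch.cuda.memory_allocated directly rather than the colossalai helper (the class name here is made up for illustration):

from concurrent.futures import ThreadPoolExecutor
from time import sleep

import torch

class PeakMemoryPoller:
    # Minimal background poller: records the peak allocated CUDA memory
    # observed between start() and finish(). Requires a CUDA device.
    def __init__(self, interval: float = 0.01):
        self.interval = interval
        self.keep_measuring = False
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.future = None

    def start(self):
        self.keep_measuring = True
        self.future = self.executor.submit(self._measure)

    def finish(self) -> int:
        self.keep_measuring = False
        return self.future.result()

    def _measure(self) -> int:
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(max_usage, torch.cuda.memory_allocated())
            sleep(self.interval)
        return max_usage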

View File

@@ -1,9 +1,9 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
import torch
from typing import Tuple
from typing import List
class SamplingCounter:
@@ -23,45 +23,71 @@ class SamplingCounter:
class MemStatsCollector:
"""
A memory statistics collector.
It works in two phases:
Phase 1. Collection phase: collect memory usage statistics of CPU and GPU
during the first iteration of DNN training.
Phase 2. Runtime phase: use the read-only collected stats
during the remaining iterations of DNN training.
It has a sampling counter which is reset after each DNN training iteration.
"""
def __init__(self) -> None:
"""
Collecting Memory Statistics.
It has two phases.
1. Collection Phase: collect memory usage statistics
2. Runtime Phase: do not collect statistics.
"""
self._sampling_cnter = SamplingCounter()
self._model_data_cuda = []
self._overall_cuda = []
self._model_data_cuda_list = []
self._overall_cuda_list = []
# TODO(jiaruifang) Now no cpu mem stats collecting
self._model_data_cpu = []
self._overall_cpu = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._start_flag = False
@property
def overall_cuda(self):
return self._overall_cuda
def overall_mem_stats(self, device_type: str):
if device_type == 'cuda':
return self._overall_cuda_list
elif device_type == 'cpu':
return self._overall_cpu_list
else:
raise TypeError(f'Unknown device type {device_type}')
@property
def model_data_cuda_GB(self):
return [elem / 1e9 for elem in self._model_data_cuda]
def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
scale = 1
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
else:
raise TypeError(f'Unknown unit {unit}')
@property
def model_data_cuda(self):
return self._model_data_cuda
if device_type == 'cuda':
return [elem / scale for elem in self._model_data_cuda_list]
elif device_type == 'cpu':
return [elem / scale for elem in self._model_data_cpu_list]
else:
raise TypeError(f'Unknown device type {device_type}')
@property
def non_model_data_cuda_GB(self):
return [elem / 1e9 for elem in self.non_model_data_cuda]
@property
def non_model_data_cuda(self):
def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
"""Non model data stats
"""
return [(v1 - v2) for v1, v2 in zip(self._overall_cuda, self._model_data_cuda)]
scale = 1
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
if device_type == 'cuda':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
elif device_type == 'cpu':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
else:
raise TypeError(f'Unknown device type {device_type}')
def start_collection(self):
self._start_flag = True
@@ -73,32 +99,28 @@ class MemStatsCollector:
"""
Sample memory statistics.
Record the current model data CUDA memory usage as well as the overall system CUDA memory usage.
Advance the sampling counter.
"""
if self._start_flag:
sampling_cnt = self._sampling_cnter.sampling_cnt
assert sampling_cnt == len(self._overall_cuda)
self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
self._sampling_cnter.advance()
assert sampling_cnt == len(self._overall_cuda_list)
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda_list.append(colo_device_memory_used(get_current_device()))
def fetch_memstats(self) -> Tuple[int, int]:
"""
Returns the CUDA usage of model data and the overall CUDA usage.
"""
sampling_cnt = self._sampling_cnter.sampling_cnt
if len(self._model_data_cuda) < sampling_cnt:
raise RuntimeError
return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt])
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
self._sampling_cnter.advance()
def reset_sampling_cnter(self) -> None:
self._sampling_cnter.reset()
def clear(self) -> None:
self._model_data_cuda = []
self._overall_cuda = []
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu = []
self._overall_cpu = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._start_flag = False
self._sampling_cnter.reset()
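Putting the refactored collector together, a hedged usage sketch of the two-phase flow. It assumes colossalai at this commit is installed and runs inside an initialized Colossal-AI context (sampling consults GLOBAL_MODEL_DATA_TRACER and the data-parallel group); the toy tensor allocations are a stand-in for real training work.

import torch
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()

# Phase 1: collect statistics during the first (warmup) training iteration.
collector.start_collection()
buffers = []
for _ in range(3):
    buffers.append(torch.empty(1024, 1024, device='cuda'))  # stand-in workload
    collector.sample_memstats()

# Phase 2: the recorded statistics are then read back, e.g. to size buffers.
print(collector.overall_mem_stats('cuda'))               # raw bytes per sample
print(collector.model_data_cuda_list('cuda', 'GB'))      # model data, scaled to GB
print(collector.non_model_data_cuda_list('cuda', 'MB'))  # overall minus model data, in MB
collector.reset_sampling_cnter()                          # rewind the counter each iteration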

View File

@@ -30,10 +30,7 @@ def test_mem_collector():
collector.sample_memstats()
collector.sample_memstats()
cuda_use, overall_use = collector.fetch_memstats()
print(cuda_use, overall_use)
print(collector.overall_cuda)
print(collector.overall_mem_stats('cuda'))
if __name__ == '__main__':

View File

@@ -9,29 +9,6 @@ import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device
from typing import Optional
def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
"""Get the free memory info of device.
Args:
device (Optional[``torch.device``]): a torch device instance or None. Defaults None.
Returns:
int: current memory usage, sized by Byte.
"""
if device:
assert device.type == 'cuda'
else:
device = torch.device(f'cuda:{get_current_device()}')
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def bytes_to_GB(val, decimal=2):

View File

@@ -1,29 +1,65 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from typing import Tuple, Union
from collections import namedtuple
import psutil
from colossalai.core import global_context as gpc
_GLOBAL_CUDA_MEM_FRACTION = 1.0
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if issubclass(type(tensor), StatefulTensor):
t = tensor.payload
elif isinstance(tensor, torch.Tensor):
t = tensor
else:
return 0, 0
# copy from PatrickStar
def _get_cpu_memory_info():
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
try:
# psutil reads the memory info from /proc/meminfo,
# which reports the host memory instead of
# that of the container.
# Here we try to read the container memory with the method in:
# https://stackoverflow.com/a/46213331/5163915
mems = {}
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
total = mems[b"MemTotal:"]
free = mems[b"MemFree:"]
cached = mems[b"Cached:"]
buffers = mems[b"Buffers:"]
used = total - free - cached - buffers
if used < 0:
used = total - free
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
except FileNotFoundError:
mems = psutil.virtual_memory()
mem_info = ps_mem_info(
total=mems.total,
free=mems.free,
cached=mems.cached,
buffers=mems.buffers,
used=mems.used,
)
return mem_info
cuda_use, cpu_use = 0, 0
mem_use = t.numel() * t.element_size()
if t.device.type == 'cuda':
cuda_use += mem_use
elif t.device.type == 'cpu':
cpu_use += mem_use
return cuda_use, cpu_use
def colo_device_memory_used(device) -> int:
if not isinstance(device, torch.device):
device = torch.device(f"cuda:{device}")
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
# FIXME(jiaruifang) this only works for the 1-CPU, multi-GPU case,
# where CPU memory usage is split evenly among all processes.
# Multi-CPU (multi-node) setups are not supported yet;
# a local_world_size would be needed for that.
ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
return ret
elif device.type == 'cuda':
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def colo_set_process_memory_fraction(ratio: float) -> None:
@@ -44,97 +80,3 @@ def colo_cuda_memory_capacity() -> float:
Get cuda memory capacity of the current cuda.
"""
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
torch.Tensor]) -> None:
"""
A Colossal-AI API for moving model data between tensors.
The source and target tensors can reside on either CPU or GPU.
NOTE: the source tensor's payload will be released after this function returns.
The function will record the communication volume between CPU and GPU.
Args:
src_t (Union[StatefulTensor, torch.Tensor]): the source tensor
tgt_t (Union[StatefulTensor, torch.Tensor]): the target tensor
"""
if issubclass(type(src_t), StatefulTensor):
src_t_payload = src_t.payload
else:
src_t_payload = src_t.data
src_dev = src_t_payload.device
if issubclass(type(tgt_t), StatefulTensor):
tgt_t_payload = tgt_t.payload
else:
tgt_t_payload = tgt_t.data
tgt_t_payload.copy_(src_t_payload)
# remove payload of src_t
if issubclass(type(src_t), StatefulTensor):
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
else:
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
int]) -> None:
"""
Move a tensor to the target_device.
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
"""
if isinstance(t, torch.Tensor):
t_payload = t
elif issubclass(type(t), StatefulTensor):
t_payload = t.payload
else:
raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
if isinstance(target_device, int):
target_device = torch.device(f'cuda:{target_device}')
# deal with torch.device('cpu') and torch.device('cpu:0')
if t_payload.device.type == target_device.type:
return
t_payload.data = t_payload.data.to(target_device)
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu
move a model data tensor from gpu to cpu
Args:
t (Union[StatefulTensor, torch.Tensor]): _description_
"""
if issubclass(type(t), StatefulTensor):
t_payload = t.payload
elif isinstance(t, torch.Tensor):
t_payload = t
else:
raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
if t_payload.device.type == 'cpu':
return
# TODO() optimize the tensor moving with non-blocking
t_payload.data = t_payload.data.cpu()
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
"""
Clone a model data tensor
Args:
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
target_device (torch.device): the target device
Returns:
torch.Tensor: a cloned torch tensor
"""
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
ret = t_payload.to(target_device)
return ret
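As a quick illustration of the unified query that replaces colo_cuda_memory_used, a hedged snippet (import paths as in this commit; it assumes a CUDA device is available):

import torch
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_utils.utils import colo_device_memory_used

x = torch.empty(256, 1024, 1024, device='cuda')  # roughly 1 GB of fp32 as a stand-in workload

used = colo_device_memory_used(get_current_device())
print(f'CUDA memory in use: {used / 1e9:.2f} GB')

# The CPU branch reads cgroup/psutil info and divides by the data-parallel
# world size, so it expects an initialized Colossal-AI context.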

View File

@@ -3,7 +3,7 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

View File

@@ -0,0 +1,117 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from typing import Union, Tuple
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if issubclass(type(tensor), StatefulTensor):
t = tensor.payload
elif isinstance(tensor, torch.Tensor):
t = tensor
else:
return 0, 0
cuda_use, cpu_use = 0, 0
mem_use = t.numel() * t.element_size()
if t.device.type == 'cuda':
cuda_use += mem_use
elif t.device.type == 'cpu':
cpu_use += mem_use
return cuda_use, cpu_use
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
torch.Tensor]) -> None:
"""
A Colossal-AI API for moving model data between tensors.
The source and target tensors can reside on either CPU or GPU.
NOTE: the source tensor's payload will be released after this function returns.
The function will record the communication volume between CPU and GPU.
Args:
src_t (Union[StatefulTensor, torch.Tensor]): the source tensor
tgt_t (Union[StatefulTensor, torch.Tensor]): the target tensor
"""
if issubclass(type(src_t), StatefulTensor):
src_t_payload = src_t.payload
else:
src_t_payload = src_t.data
src_dev = src_t_payload.device
if issubclass(type(tgt_t), StatefulTensor):
tgt_t_payload = tgt_t.payload
else:
tgt_t_payload = tgt_t.data
tgt_t_payload.copy_(src_t_payload)
# remove payload of src_t
if issubclass(type(src_t), StatefulTensor):
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
else:
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
int]) -> None:
"""
Move a tensor to the target_device.
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
target_device (Union[torch.device, int]): the target device; if it is an int, it is interpreted as a CUDA device index.
"""
if isinstance(t, torch.Tensor):
t_payload = t
elif issubclass(type(t), StatefulTensor):
t_payload = t.payload
else:
raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
if not isinstance(target_device, torch.device):
target_device = torch.device(f'cuda:{target_device}')
# deal with torch.device('cpu') and torch.device('cpu:0')
if t_payload.device.type == target_device.type:
return
t_payload.data = t_payload.data.to(target_device)
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu
move a model data tensor from gpu to cpu
Args:
t (Union[StatefulTensor, torch.Tensor]): _description_
"""
if issubclass(type(t), StatefulTensor):
t_payload = t.payload
elif isinstance(t, torch.Tensor):
t_payload = t
else:
raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
if t_payload.device.type == 'cpu':
return
# TODO() optimize the tensor moving with non-blocking
t_payload.data = t_payload.data.cpu()
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
"""
Clone a model data tensor
Args:
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
target_device (torch.device): the target device
Returns:
torch.Tensor: a cloned torch tensor
"""
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
ret = t_payload.to(target_device)
return ret
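A brief hedged sketch of the relocated helpers under their new module path, exercising the move and clone functions on plain tensors (assumes colossalai at this commit and an available CUDA device):

import torch
from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                      colo_tensor_mem_usage)

t = torch.ones(1024, 1024)                    # starts on CPU
cuda_use, cpu_use = colo_tensor_mem_usage(t)  # (0, 4194304) for fp32
print(cuda_use, cpu_use)

colo_model_data_tensor_move_inline(t, torch.device('cuda:0'))  # in-place device move
clone = colo_model_tensor_clone(t, torch.device('cpu'))        # copy back to CPU
print(t.device, clone.device)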

View File

@@ -16,8 +16,9 @@ from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
@@ -160,11 +161,13 @@ class ShardedModelV2(nn.Module):
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
f.write('model data\n')
f.write(str(self._memstats_collector.model_data_cuda_GB))
f.write('CUDA model data (GB)\n')
f.write(str(self._memstats_collector.model_data_cuda_list('cuda', 'GB')))
f.write('\n')
f.write('non model data\n')
f.write(str(self._memstats_collector.non_model_data_cuda_GB))
f.write('CUDA non model data (GB)\n')
f.write(str(self._memstats_collector.non_model_data_cuda_list('cuda', 'GB')))
f.write('CPU non model data (GB)\n')
f.write(str(self._memstats_collector.non_model_data_cuda_list('cpu', 'GB')))
f.write('\n')
def _pre_forward_operations(self):
@@ -209,7 +212,8 @@ class ShardedModelV2(nn.Module):
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store optimizer states (OS).
self._cuda_margin_space = colo_cuda_memory_capacity() - max(self._memstats_collector.overall_cuda)
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
self._memstats_collector.overall_mem_stats('cuda'))
self._iter_cnter += 1
@torch.no_grad()
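The margin-space line above boils down to "device capacity minus the peak overall usage observed during warmup"; as a rough arithmetic sketch with made-up numbers:

# Illustrative numbers only, not from a real run.
capacity_gb = 16.0      # colo_cuda_memory_capacity() / 1e9
peak_overall_gb = 11.5  # max(overall_mem_stats('cuda')) / 1e9
cuda_margin_space_gb = capacity_gb - peak_overall_gb
print(cuda_margin_space_gb)  # 4.5 GB left over for optimizer states (OS)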

View File

@@ -12,12 +12,13 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.shard_utils.tensor_utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

View File

@@ -1,8 +1,7 @@
import torch
import torch.distributed as dist
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState

View File

@@ -1,4 +1,4 @@
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.zero.sharded_param import ShardedTensor

View File

@@ -1,7 +1,7 @@
import pytest
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.zero.sharded_param import ShardedTensor
import colossalai

View File

@@ -13,7 +13,7 @@ from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
colo_model_mem_usage
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -51,10 +51,10 @@ def run_model_test(init_device_type, shard_strategy_class):
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
cuda_mem_use, _ = colo_model_mem_usage(model)
model_data_cuda_mem_MB = cuda_mem_use / 1e6
logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6
logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])