[utils] correct cpu memory used and capacity in the context of multi-process (#726)

pull/739/head
Jiarui Fang 2022-04-12 14:57:54 +08:00 committed by GitHub
parent 7db3ccc79b
commit 53cb584808
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 52 additions and 20 deletions

View File

@ -8,6 +8,7 @@ from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from packaging import version
_GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CUDA_MEM_FRACTION = 1.0
@ -106,7 +107,8 @@ def colo_device_memory_capacity(device: torch.device) -> int:
assert isinstance(device, torch.device) assert isinstance(device, torch.device)
if device.type == 'cpu': if device.type == 'cpu':
mem_info = _get_cpu_memory_info() mem_info = _get_cpu_memory_info()
return mem_info.info.total / gpc.get_world_size(ParallelMode.DATA) # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return mem_info.total / gpc.num_processes_on_current_node
if device.type == 'cuda': if device.type == 'cuda':
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
@ -123,8 +125,9 @@ def colo_device_memory_used(device: torch.device) -> int:
""" """
if device.type == 'cpu': if device.type == 'cpu':
mem_info = _get_cpu_memory_info() mem_info = _get_cpu_memory_info()
# FIXME(jiaruifang) we need get how many processes are using the CPU memory. # In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used.
ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA) # Each process consumes the same amount of memory.
ret = mem_info.used / gpc.num_processes_on_current_node
return ret return ret
elif device.type == 'cuda': elif device.type == 'cuda':
ret: int = torch.cuda.memory_allocated(device) ret: int = torch.cuda.memory_allocated(device)
@ -142,6 +145,10 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
Args: Args:
ratio (float): a ratio between 0. ~ 1. ratio (float): a ratio between 0. ~ 1.
""" """
if version.parse(torch.__version__) < version.parse('1.8'):
logger = get_dist_logger('colo_set_process_memory_fraction')
logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
return
global _GLOBAL_CUDA_MEM_FRACTION global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio _GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())

View File

@ -31,6 +31,7 @@ class AsyncMemoryMonitor:
async_mem_monitor.finish() async_mem_monitor.finish()
async_mem_monitor.save('log.pkl') async_mem_monitor.save('log.pkl')
Args: Args:
power (int, optional): the power of time interva. Defaults to 10. power (int, optional): the power of time interva. Defaults to 10.

View File

@ -16,7 +16,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_on_exception
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from tests.test_zero_data_parallel.common import CONFIG from tests.test_zero.common import CONFIG
class MoeModel(CheckpointModule): class MoeModel(CheckpointModule):

View File

@ -16,7 +16,7 @@ from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.context import MOE_CONTEXT from colossalai.context import MOE_CONTEXT
from colossalai.testing import assert_equal_in_group from colossalai.testing import assert_equal_in_group
from tests.test_zero_data_parallel.common import CONFIG, check_grads_padding, run_fwd_bwd from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd
from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_moe.test_moe_zero_init import MoeModel

View File

@ -20,7 +20,7 @@ from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.context import MOE_CONTEXT from colossalai.context import MOE_CONTEXT
from colossalai.testing import assert_equal_in_group from colossalai.testing import assert_equal_in_group
from tests.test_zero_data_parallel.common import CONFIG, check_sharded_model_params from tests.test_zero.common import CONFIG, check_sharded_model_params
from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_moe.test_moe_zero_init import MoeModel

View File

@ -0,0 +1,32 @@
import pytest
import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port
from functools import partial
import torch.multiprocessing as mp
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
frac1 = colo_device_memory_capacity(get_current_device())
colo_set_process_memory_fraction(0.5)
frac2 = colo_device_memory_capacity(get_current_device())
assert frac2 * 2 == frac1
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5])
def test_memory_utils(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_memory_utils(world_size=2)

View File

@ -14,7 +14,7 @@ from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_zero_data_parallel.test_sharded_optim_v2 import _run_step from tests.test_zero.test_sharded_optim_v2 import _run_step
from common import CONFIG from common import CONFIG

View File

@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from colossalai.zero.sharded_param import ShardedTensor from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_on_exception
from tests.test_zero_data_parallel.common import CONFIG, allclose from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

View File

@ -5,7 +5,6 @@ from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move, from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone) colo_model_tensor_clone)
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port from colossalai.utils import free_port
import torch import torch
@ -32,13 +31,6 @@ def _run_colo_tensor_mem_usage():
assert g1 * 4 == g2 assert g1 * 4 == g2
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
frac1 = colo_device_memory_capacity(get_current_device())
colo_set_process_memory_fraction(0.5)
frac2 = colo_device_memory_capacity(get_current_device())
assert frac2 * 2 == frac1
def _run_colo_model_data_tensor_move_inline(): def _run_colo_model_data_tensor_move_inline():
for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
colo_model_data_tensor_move_inline(t, get_current_device()) colo_model_data_tensor_move_inline(t, get_current_device())
@ -82,20 +74,20 @@ def _run_colo_model_tensor_clone():
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()
_run_colo_tensor_mem_usage()
_run_colo_model_data_tensor_move_inline() _run_colo_model_data_tensor_move_inline()
_run_colo_model_data_tensor_move() _run_colo_model_data_tensor_move()
_run_colo_tensor_mem_usage()
_run_colo_model_data_move_to_cpu() _run_colo_model_data_move_to_cpu()
_run_colo_model_tensor_clone() _run_colo_model_tensor_clone()
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5]) @pytest.mark.parametrize("world_size", [4, 5])
def test_tensor_move(world_size): def test_zero_tensor_utils(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
test_tensor_move(4) test_zero_tensor_utils(world_size=2)