[gemini] APIs to set cpu memory capacity (#809)

pull/810/head
Jiarui Fang 2022-04-19 16:05:22 +08:00 committed by GitHub
parent f6dcd23fb9
commit 227d1cd4b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 7 deletions

View File

@ -275,9 +275,6 @@ def initialize(model: nn.Module,
optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
# FIXME() throw a warning if using zero with MP
if gpc.get_world_size(ParallelMode.MODEL) > 1:
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
else:
if isinstance(model, nn.Module):
# first sync model across dp ranks

View File

@ -7,7 +7,8 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param, disposable)
from .data_sampler import DataParallelSampler, get_dataloader
from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity
from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector
@ -19,5 +20,5 @@ __all__ = [
'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
'ensure_path_exists', 'disposable'
'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity'
]

View File

@ -11,6 +11,7 @@ from colossalai.logging import get_dist_logger
from packaging import version
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1
def _bytes_to_MB(val, decimal=2):
@ -106,9 +107,8 @@ def colo_device_memory_capacity(device: torch.device) -> int:
"""
assert isinstance(device, torch.device)
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return mem_info.total / gpc.num_processes_on_current_node
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
if device.type == 'cuda':
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
@ -152,3 +152,27 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
def colo_set_cpu_memory_capacity(size: int) -> None:
global _GLOBAL_CPU_MEM_CAPACITY
mem_info = _get_cpu_memory_info()
total_size = mem_info.total
if size <= total_size:
_GLOBAL_CPU_MEM_CAPACITY = size
else:
_GLOBAL_CPU_MEM_CAPACITY = total_size
def colo_get_cpu_memory_capacity() -> int:
"""
Get the cpu memory capacity. We may not use all of it.
Returns:
int: _description_
"""
global _GLOBAL_CPU_MEM_CAPACITY
if _GLOBAL_CPU_MEM_CAPACITY == -1:
mem_info = _get_cpu_memory_info()
return mem_info.total
else:
return _GLOBAL_CPU_MEM_CAPACITY