mirror of https://github.com/hpcaitech/ColossalAI
[gemini] APIs to set cpu memory capacity (#809)
parent
f6dcd23fb9
commit
227d1cd4b3
|
@ -275,9 +275,6 @@ def initialize(model: nn.Module,
|
|||
optimizer_config=optimizer_config)
|
||||
|
||||
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
|
||||
# FIXME() throw a warning if using zero with MP
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
|
||||
else:
|
||||
if isinstance(model, nn.Module):
|
||||
# first sync model across dp ranks
|
||||
|
|
|
@ -7,7 +7,8 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
|
|||
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param, disposable)
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity
|
||||
from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
|
||||
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
|
||||
from .timer import MultiTimer, Timer
|
||||
from .tensor_detector import TensorDetector
|
||||
|
||||
|
@ -19,5 +20,5 @@ __all__ = [
|
|||
'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
|
||||
'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
|
||||
'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
|
||||
'ensure_path_exists', 'disposable'
|
||||
'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity'
|
||||
]
|
||||
|
|
|
@ -11,6 +11,7 @@ from colossalai.logging import get_dist_logger
|
|||
from packaging import version
|
||||
|
||||
_GLOBAL_CUDA_MEM_FRACTION = 1.0
|
||||
_GLOBAL_CPU_MEM_CAPACITY = -1
|
||||
|
||||
|
||||
def _bytes_to_MB(val, decimal=2):
|
||||
|
@ -106,9 +107,8 @@ def colo_device_memory_capacity(device: torch.device) -> int:
|
|||
"""
|
||||
assert isinstance(device, torch.device)
|
||||
if device.type == 'cpu':
|
||||
mem_info = _get_cpu_memory_info()
|
||||
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
|
||||
return mem_info.total / gpc.num_processes_on_current_node
|
||||
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
|
||||
if device.type == 'cuda':
|
||||
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
|
||||
|
||||
|
@ -152,3 +152,27 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
|
|||
global _GLOBAL_CUDA_MEM_FRACTION
|
||||
_GLOBAL_CUDA_MEM_FRACTION = ratio
|
||||
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
|
||||
|
||||
|
||||
def colo_set_cpu_memory_capacity(size: int) -> None:
|
||||
global _GLOBAL_CPU_MEM_CAPACITY
|
||||
mem_info = _get_cpu_memory_info()
|
||||
total_size = mem_info.total
|
||||
if size <= total_size:
|
||||
_GLOBAL_CPU_MEM_CAPACITY = size
|
||||
else:
|
||||
_GLOBAL_CPU_MEM_CAPACITY = total_size
|
||||
|
||||
|
||||
def colo_get_cpu_memory_capacity() -> int:
|
||||
"""
|
||||
Get the cpu memory capacity. We may not use all of it.
|
||||
Returns:
|
||||
int: _description_
|
||||
"""
|
||||
global _GLOBAL_CPU_MEM_CAPACITY
|
||||
if _GLOBAL_CPU_MEM_CAPACITY == -1:
|
||||
mem_info = _get_cpu_memory_info()
|
||||
return mem_info.total
|
||||
else:
|
||||
return _GLOBAL_CPU_MEM_CAPACITY
|
||||
|
|
Loading…
Reference in New Issue