diff --git a/colossalai/initialize.py b/colossalai/initialize.py index b806356e4..08bd43f62 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -275,9 +275,6 @@ def initialize(model: nn.Module, optimizer_config=optimizer_config) logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) - # FIXME() throw a warning if using zero with MP - if gpc.get_world_size(ParallelMode.MODEL) > 1: - logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0]) else: if isinstance(model, nn.Module): # first sync model across dp ranks diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index fa69a33f6..6e1720b3d 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -7,7 +7,8 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param, disposable) from .data_sampler import DataParallelSampler, get_dataloader -from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity +from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, + colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity) from .timer import MultiTimer, Timer from .tensor_detector import TensorDetector @@ -19,5 +20,5 @@ __all__ = [ 'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint', - 'ensure_path_exists', 'disposable' + 'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity' ] diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py index 12c23c3a0..434e90edd 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/utils/memory.py @@ -11,6 +11,7 @@ from colossalai.logging import get_dist_logger from packaging import version _GLOBAL_CUDA_MEM_FRACTION = 1.0 +_GLOBAL_CPU_MEM_CAPACITY = -1 def _bytes_to_MB(val, decimal=2): @@ -106,9 +107,8 @@ def colo_device_memory_capacity(device: torch.device) -> int: """ assert isinstance(device, torch.device) if device.type == 'cpu': - mem_info = _get_cpu_memory_info() # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. - return mem_info.total / gpc.num_processes_on_current_node + return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node if device.type == 'cuda': return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION @@ -152,3 +152,27 @@ def colo_set_process_memory_fraction(ratio: float) -> None: global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + + +def colo_set_cpu_memory_capacity(size: int) -> None: + global _GLOBAL_CPU_MEM_CAPACITY + mem_info = _get_cpu_memory_info() + total_size = mem_info.total + if size <= total_size: + _GLOBAL_CPU_MEM_CAPACITY = size + else: + _GLOBAL_CPU_MEM_CAPACITY = total_size + + +def colo_get_cpu_memory_capacity() -> int: + """ + Get the cpu memory capacity. We may not use all of it. + Returns: + int: _description_ + """ + global _GLOBAL_CPU_MEM_CAPACITY + if _GLOBAL_CPU_MEM_CAPACITY == -1: + mem_info = _get_cpu_memory_info() + return mem_info.total + else: + return _GLOBAL_CPU_MEM_CAPACITY