diff --git a/colossalai/utils/commons/__init__.py b/colossalai/utils/commons/__init__.py deleted file mode 100644 index e48fad25c..000000000 --- a/colossalai/utils/commons/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .bucket_tensor_copy import BucketizedTensorCopy - -__all__ = ['BucketizedTensorCopy'] diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py index 842aafbdd..4091f94aa 100644 --- a/colossalai/utils/memory_tracer/async_memtracer.py +++ b/colossalai/utils/memory_tracer/async_memtracer.py @@ -29,6 +29,10 @@ class AsyncMemoryMonitor: An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU at interval of 1/(10**power) sec. + The idea comes from Runtime Memory Tracer of PatrickStar + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + https://arxiv.org/abs/2108.05818 + :param power: the power of time interval, defaults to 10 :type power: int @@ -54,6 +58,7 @@ class AsyncMemoryMonitor: self.keep_measuring = False current_device = get_current_device() + def _set_cuda_device(): torch.cuda.set_device(current_device) diff --git a/colossalai/utils/commons/bucket_tensor_copy.py b/colossalai/utils/memory_utils/bucket_tensor_copy.py similarity index 100% rename from colossalai/utils/commons/bucket_tensor_copy.py rename to colossalai/utils/memory_utils/bucket_tensor_copy.py diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py index b049a92ca..b1c24994c 100644 --- a/colossalai/utils/memory_utils/utils.py +++ b/colossalai/utils/memory_utils/utils.py @@ -5,12 +5,27 @@ from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DAT from typing import Union +_GLOBAL_CUDA_MEM_FRACTION = 1.0 -def colo_cuda_memory_capacity(): + +def colo_set_process_memory_fraction(ratio: float) -> None: + """colo_set_process_memory_fraction + + set how much cuda memory used on the gpu belonging to the current process. + + Args: + ratio (float): a ratio between 0. ~ 1. + """ + global _GLOBAL_CUDA_MEM_FRACTION + _GLOBAL_CUDA_MEM_FRACTION = ratio + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + + +def colo_cuda_memory_capacity() -> float: """ Get cuda memory capacity of the current cuda. """ - return torch.cuda.get_device_properties(get_current_device()).total_memory + return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t: Union[ShardedTensor, @@ -50,10 +65,25 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype) -def colo_model_data_move_to_cpu(t: torch.Tensor): - assert isinstance(t, torch.Tensor) - if t.device.type == 'cpu': +def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None: + """colo_model_data_move_to_cpu + + move a model data tensor from gpu to cpu + + Args: + t (Union[ShardedTensor, torch.Tensor]): _description_ + """ + + if isinstance(t, ShardedTensor): + t_payload = t.payload + elif isinstance(t, torch.Tensor): + t_payload = t + else: + raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}') + + if t_payload.device.type == 'cpu': return - GLOBAL_MODEL_DATA_TRACER.delete_tensor(t) - t.data = t.data.cpu() + # TODO() optimize the tensor moving with non-blocking + GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload) + t_payload.data = t_payload.data.cpu() diff --git a/tests/test_utils/test_bucket_tensor_copy.py b/tests/test_utils/test_bucket_tensor_copy.py index 31d534b78..f190cb522 100644 --- a/tests/test_utils/test_bucket_tensor_copy.py +++ b/tests/test_utils/test_bucket_tensor_copy.py @@ -1,4 +1,4 @@ -from colossalai.utils.commons import BucketizedTensorCopy +from colossalai.utils.memory_utils.bucket_tensor_copy import BucketizedTensorCopy from colossalai.zero.sharded_param import ShardedParamV2 from colossalai.utils import free_port import torch