[gemini] collect cpu-gpu moving volume in each iteration (#813)

pull/815/head
Jiarui Fang 2022-04-20 11:29:48 +08:00 committed by GitHub
parent 61c20b44bc
commit 3ddbd1bce1
3 changed files with 44 additions and 9 deletions


@@ -28,6 +28,8 @@ class StatefulTensorMgr(object):
         self._compute_list: List[StatefulTensor] = []
         self._compute_idx: int = -1
+        self._cpu_gpu_move_volume = 0
+
     def register_stateful_param(self, param: ShardedParamV2) -> None:
         assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
@@ -56,7 +58,7 @@ class StatefulTensorMgr(object):
                     cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
             else:
                 raise RuntimeError
-        self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
+        self._cpu_gpu_move_volume += self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                     cuda_demand=cuda_demand,
                                                     warmup=self._warmup,
                                                     compute_list=self._compute_list,
@@ -64,12 +66,18 @@ class StatefulTensorMgr(object):
         # move COMPUTE tensors to CUDA
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())
+            self._cpu_gpu_move_volume += t.payload.numel() * t.payload.element_size()
+
+    @property
+    def cpu_gpu_move_volume(self):
+        return self._cpu_gpu_move_volume
 
     def reset(self):
         """This function must be called when each iteration finishes
         """
         self._warmup = False
         self._compute_idx = -1
+        self._cpu_gpu_move_volume = 0
 
     def _trans_state(self, trans_state_func, stateful_tensor, state):
         trans_state_func(state)
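For context, the quantity accumulated above is just the byte footprint of every payload that crosses the CPU-GPU boundary during one iteration. A minimal sketch of that bookkeeping outside ColossalAI (the `MoveVolumeCounter` class below is illustrative, not part of the library):

```python
import torch


class MoveVolumeCounter:
    """Illustrative stand-in for the counting done by StatefulTensorMgr."""

    def __init__(self):
        self.volume = 0  # bytes moved since the last reset

    def record(self, tensor: torch.Tensor) -> None:
        # numel() * element_size() gives the payload size in bytes,
        # the same formula used in the diff above.
        self.volume += tensor.numel() * tensor.element_size()

    def reset(self) -> None:
        # called once per iteration, mirroring StatefulTensorMgr.reset()
        self.volume = 0


counter = MoveVolumeCounter()
t = torch.zeros(1024, 1024, dtype=torch.float16)  # ~2 MB payload
counter.record(t)
print(f"moved {counter.volume / 1e9} GB so far this iteration")  # 0.002097152 GB
```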


@@ -27,9 +27,12 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
 
-    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
+        volume = 0
         for t in hold_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, self.device)
+            volume += t.payload.numel() * t.payload.element_size()
+        return volume
 
 
 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
@@ -38,8 +41,8 @@ class CUDATensorPlacementPolicy(TensorPlacementPolicy):
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
 
-    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
-        pass
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
+        return 0
 
 
 class AutoTensorPlacementPolicy(TensorPlacementPolicy):
@@ -57,7 +60,24 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
                       warmup: bool = True,
                       compute_list: List[StatefulTensor] = [],
                       compute_idx: int = 0,
-                      **kwargs) -> None:
+                      **kwargs) -> int:
+        """
+        Evict tensors from the CUDA device.
+
+        Args:
+            hold_cuda_tensor_list (List[StatefulTensor]): the list of tensors in a HOLD-like state
+            cuda_demand (int, optional): the volume of data needed on the CUDA device. Defaults to 0.
+            warmup (bool, optional): a flag indicating whether the warmup phase is running. Defaults to True.
+            compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
+            compute_idx (int, optional): the index of the latest compute step in compute_list. Defaults to 0.
+
+        Raises:
+            RuntimeError: if not enough CUDA memory can be freed to satisfy cuda_demand.
+
+        Returns:
+            int: the volume of memory evicted, in bytes
+        """
+        volume = 0
         cuda_capacity = colo_device_memory_capacity(get_current_device())
         used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
         if warmup:
@@ -87,11 +107,14 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
                     break
                 freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
                 colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+                volume += t.payload.numel() * t.payload.element_size()
             if freed_cuda_model_data < to_free_cuda_model_data:
                 raise RuntimeError(
                     f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                 )
+        return volume
 
 
 class TensorPlacementPolicyFactory:
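The core of the AutoTensorPlacementPolicy change is the existing evict-until-enough-is-freed loop, which now also reports how many bytes it offloaded so the caller can fold that into the per-iteration counter. A simplified sketch of the same pattern, assuming plain torch.Tensor objects instead of StatefulTensor and a hypothetical `evict_until_freed` helper:

```python
from typing import List

import torch


def evict_until_freed(hold_tensors: List[torch.Tensor], to_free: int) -> int:
    """Move tensors to CPU until at least `to_free` bytes are released.

    Returns the number of bytes moved, like the patched evict_tensors().
    """
    freed = 0
    volume = 0
    for t in hold_tensors:
        if freed >= to_free:
            break
        nbytes = t.numel() * t.element_size()
        t.data = t.data.cpu()   # offload the payload to host memory
        freed += nbytes
        volume += nbytes        # account for the CPU-GPU traffic
    if freed < to_free:
        raise RuntimeError(f"Not enough CUDA memory! Need {to_free}, freed {freed}")
    return volume
```

In the real code the caller is StatefulTensorMgr, which adds the returned value to self._cpu_gpu_move_volume instead of discarding it.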


@@ -2,6 +2,7 @@ from typing import Optional
 import torch
 import torch.distributed as dist
+from colossalai.logging import get_dist_logger
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
@@ -27,6 +28,7 @@ class ZeroHook(BaseOpHook):
                  stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                  process_group: Optional[dist.ProcessGroup] = None):
         super().__init__()
+        self.logger = get_dist_logger("ZeROHook")
         self.shard_strategy = shard_strategy
         self.process_group = process_group
@@ -112,4 +114,6 @@ class ZeroHook(BaseOpHook):
     def post_iter(self):
         if self._stateful_tensor_mgr:
+            self.logger.info(
+                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
             self._stateful_tensor_mgr.reset()
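With the logger in place, each iteration ends with a single rank-0 line reporting the accumulated traffic, followed by a counter reset. A rough sketch of that report-then-reset flow using the standard logging module (the DemoTracker class below is a placeholder with made-up numbers, not a ColossalAI API):

```python
import logging


class DemoTracker:
    """Placeholder exposing the two members ZeroHook relies on."""

    def __init__(self):
        self.cpu_gpu_move_volume = 4_000_000_000  # pretend 4 GB moved this iteration

    def reset(self):
        self.cpu_gpu_move_volume = 0


def report_and_reset(tracker: DemoTracker, logger: logging.Logger) -> None:
    # cpu_gpu_move_volume is in bytes; dividing by 1e9 reports GB,
    # matching the f-string added to ZeroHook.post_iter().
    logger.info(f"CPU-GPU data moving this iteration {tracker.cpu_gpu_move_volume / 1e9} GB")
    tracker.reset()  # run at the end of every iteration, like StatefulTensorMgr.reset()


logging.basicConfig(level=logging.INFO)
report_and_reset(DemoTracker(), logging.getLogger("demo"))  # logs "... 4.0 GB"
```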