[tensor] chunk manager monitor mem usage (#1076)

pull/1075/head^2
ver217 2022-06-07 15:00:00 +08:00 committed by GitHub
parent 98cdbf49c6
commit 1b17859328
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 1 deletions

View File

@ -54,6 +54,7 @@ class Chunk:
if not self.is_src_rank:
self.data.storage().resize_(0)
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
self.mem = self.size * self.data.element_size()
def append(self, tensor: torch.Tensor) -> None:
assert tensor.dtype == self.dtype
@ -167,6 +168,10 @@ class Chunk:
self.data.copy_(dest_chunk.data)
self._update_tensors_ptr()
@property
def device_type(self) -> str:
return self.data.device.type
class ChunkManager:
@ -184,6 +189,7 @@ class ChunkManager:
self.lazy_release_tensors: List[torch.Tensor] = []
if enable_distributed_storage and chunk_size is None:
self.rank_load: Dict[str, torch.Tensor] = {}
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
assert tensor not in self.tensor_chunk_map
@ -202,6 +208,8 @@ class ChunkManager:
self.rank_load[group_name][src_rank] += chunk_size
self.chunk_groups[group_name].append(chunk)
chunk.append(tensor)
if not chunk.is_free:
self.total_mem[chunk.device_type] += chunk.mem
self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
if not self.enable_distributed_storage:
self.accessed_chunks.add(self.chunk_groups[group_name][-1])
@ -222,8 +230,11 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor]
if chunk in self.accessed_chunks:
return
if not chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem
chunk.access()
self.accessed_chunks.add(chunk)
self.total_mem[chunk.device_type] += chunk.mem
def release_chunk(self, tensor: torch.Tensor) -> None:
if not self.enable_distributed_storage:
@ -234,11 +245,17 @@ class ChunkManager:
if chunk.can_release:
chunk.release()
self.accessed_chunks.remove(chunk)
if chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem
def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
chunk = self.tensor_chunk_map[tensor]
if chunk.can_move_device:
if chunk.data.device == device:
return
if chunk.can_move_device and not chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(device)
self.total_mem[chunk.device_type] += chunk.mem
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
chunk = self.tensor_chunk_map[tensor]
@ -248,7 +265,9 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor]
if not chunk.can_reduce:
return False
self.total_mem[chunk.device_type] -= chunk.mem
chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
self.total_mem[chunk.device_type] += chunk.mem
return True
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
@ -272,6 +291,7 @@ class ChunkManager:
def __repr__(self) -> str:
msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
for group_name, group in self.chunk_groups.items():
msg += f'Group {group_name}:\n'
for i, chunk in enumerate(group):

View File

@ -32,6 +32,8 @@ HAS_TENSORS = {
}
}
TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@ -42,15 +44,27 @@ def run_chunk_zero(use_chunk, use_zero):
params = [torch.rand(32, 32) for _ in range(3)]
chunk_size = 2048 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
for p in params:
chunk_manager.append_tensor(p, 'param')
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
for p in params:
chunk_manager.access_chunk(p)
check_has_params(params, [True, True, True])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
for p in params:
chunk_manager.release_chunk(p)
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
for p in params:
chunk_manager.move_chunk(p, torch.device('cpu'))
assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
assert chunk_manager.total_mem['cuda'] == 0
def run_dist(rank, world_size, port):