mirror of https://github.com/hpcaitech/ColossalAI
[tensor] chunk manager monitor mem usage (#1076)
parent 98cdbf49c6
commit 1b17859328
@@ -54,6 +54,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(0)
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+        self.mem = self.size * self.data.element_size()
 
     def append(self, tensor: torch.Tensor) -> None:
         assert tensor.dtype == self.dtype
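The added line gives every Chunk a fixed byte footprint: element count times element size. A minimal sketch of the same computation with a plain tensor (not the actual Chunk class), using the 2048-element float32 chunk size that the test below uses:

    import torch

    # a chunk holding 2048 float32 elements occupies 2048 * 4 = 8192 bytes
    chunk_data = torch.empty(2048, dtype=torch.float32)
    mem = chunk_data.numel() * chunk_data.element_size()
    assert mem == 8192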
@@ -167,6 +168,10 @@ class Chunk:
         self.data.copy_(dest_chunk.data)
         self._update_tensors_ptr()
 
+    @property
+    def device_type(self) -> str:
+        return self.data.device.type
+
 
 class ChunkManager:
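The new device_type property just exposes torch's device.type string ('cpu' or 'cuda'), which is exactly the key the per-device totals below are indexed by. Tiny illustration with a plain tensor, not the Chunk class:

    import torch

    t = torch.empty(4)
    assert t.device.type == 'cpu'   # a CUDA tensor would give 'cuda'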
@@ -184,6 +189,7 @@ class ChunkManager:
         self.lazy_release_tensors: List[torch.Tensor] = []
         if enable_distributed_storage and chunk_size is None:
             self.rank_load: Dict[str, torch.Tensor] = {}
+        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
     def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
         assert tensor not in self.tensor_chunk_map
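total_mem starts at zero for both devices, and the remaining hunks keep it equal to the sum of chunk.mem over chunks that actually hold storage on that device. Free chunks (e.g. non-source ranks after storage().resize_(0)) are not counted, which is why append_tensor below only adds chunk.mem when the new chunk is not free. A rough sketch of that invariant with stand-in objects rather than the real Chunk/ChunkManager:

    from collections import namedtuple

    FakeChunk = namedtuple('FakeChunk', ['mem', 'device_type', 'is_free'])
    chunks = [FakeChunk(8192, 'cuda', False),
              FakeChunk(8192, 'cpu', False),
              FakeChunk(8192, 'cuda', True)]   # free chunk: no payload stored

    total_mem = {'cpu': 0, 'cuda': 0}
    for c in chunks:
        if not c.is_free:
            total_mem[c.device_type] += c.mem
    assert total_mem == {'cpu': 8192, 'cuda': 8192}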
@@ -202,6 +208,8 @@
                 self.rank_load[group_name][src_rank] += chunk_size
             self.chunk_groups[group_name].append(chunk)
             chunk.append(tensor)
+            if not chunk.is_free:
+                self.total_mem[chunk.device_type] += chunk.mem
         self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
         if not self.enable_distributed_storage:
             self.accessed_chunks.add(self.chunk_groups[group_name][-1])
@@ -222,8 +230,11 @@
         chunk = self.tensor_chunk_map[tensor]
         if chunk in self.accessed_chunks:
             return
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
         chunk.access()
         self.accessed_chunks.add(chunk)
+        self.total_mem[chunk.device_type] += chunk.mem
 
     def release_chunk(self, tensor: torch.Tensor) -> None:
         if not self.enable_distributed_storage:
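access_chunk now brackets materialization with counter updates: the chunk's bytes are removed under its old state (only if it held storage) and re-added once the full payload is present, so after accessing everything each rank's 'cuda' total matches the non-distributed layout (the test checks TOTAL_MEM[use_chunk][False][rank] at that point). A condensed restatement of the pattern, not the real method:

    def access_chunk_accounting(total_mem, chunk):
        if not chunk.is_free:
            total_mem[chunk.device_type] -= chunk.mem
        chunk.access()                     # materialize the full payload on this rank
        total_mem[chunk.device_type] += chunk.mem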
@@ -234,11 +245,17 @@
         if chunk.can_release:
             chunk.release()
             self.accessed_chunks.remove(chunk)
+            if chunk.is_free:
+                self.total_mem[chunk.device_type] -= chunk.mem
 
     def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
         chunk = self.tensor_chunk_map[tensor]
-        if chunk.can_move_device:
+        if chunk.data.device == device:
+            return
+        if chunk.can_move_device and not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
             chunk.move_device(device)
+            self.total_mem[chunk.device_type] += chunk.mem
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
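release_chunk only subtracts the bytes when the chunk actually becomes free, while move_chunk transfers them from one device's counter to the other. Toy illustration of the move with plain ints (not the real ChunkManager state), matching the final step of the test where everything is moved to CPU:

    total_mem = {'cpu': 0, 'cuda': 8192}
    chunk_mem, src, dst = 8192, 'cuda', 'cpu'
    total_mem[src] -= chunk_mem   # bytes leave the old device
    total_mem[dst] += chunk_mem   # and are counted on the new one
    assert total_mem == {'cpu': 8192, 'cuda': 0}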
@@ -248,7 +265,9 @@
         chunk = self.tensor_chunk_map[tensor]
         if not chunk.can_reduce:
             return False
+        self.total_mem[chunk.device_type] -= chunk.mem
         chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+        self.total_mem[chunk.device_type] += chunk.mem
         return True
 
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
@@ -272,6 +291,7 @@
 
     def __repr__(self) -> str:
         msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
             for i, chunk in enumerate(group):
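The extra line in __repr__ prints the two counters alongside the existing per-group report. What that line produces for illustrative values:

    total_mem = {'cpu': 0, 'cuda': 8192}
    line = 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in total_mem.items()]) + '\n'
    assert line == 'Total memory: cpu=0B, cuda=8192B\n'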
@@ -32,6 +32,8 @@ HAS_TENSORS = {
     }
 }
 
+TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+
 
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
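The TOTAL_MEM table (indexed as TOTAL_MEM[use_chunk][use_zero][rank], with two entries per list, so the test apparently runs on two processes) follows directly from the tensor and chunk sizes used below:

    param_bytes = 32 * 32 * 4            # one 32x32 float32 tensor = 4096 B
    assert 3 * param_bytes == 12288      # no chunking, no ZeRO: every rank holds all 3 params
    chunk_bytes = 2048 * 4               # chunk_size=2048 elements = 8192 B per chunk
    assert 2 * chunk_bytes == 16384      # 3*1024 elements fill 2 chunks on every rank
    # with distributed storage the chunks/tensors are split across the 2 ranks:
    # [8192, 8192] in the chunked case, [8192, 4096] when each tensor is its own
    # chunk (rank 0 stores two of the three tensors, rank 1 stores one)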
@@ -42,15 +44,27 @@ def run_chunk_zero(use_chunk, use_zero):
     params = [torch.rand(32, 32) for _ in range(3)]
     chunk_size = 2048 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == 0
     for p in params:
         chunk_manager.append_tensor(p, 'param')
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
     for p in params:
         chunk_manager.access_chunk(p)
     check_has_params(params, [True, True, True])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
     for p in params:
         chunk_manager.release_chunk(p)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
+    for p in params:
+        chunk_manager.move_chunk(p, torch.device('cpu'))
+    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
+    assert chunk_manager.total_mem['cuda'] == 0
 
 
 def run_dist(rank, world_size, port):