diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index 12cdb7957..b9f1ad466 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -54,6 +54,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(0)
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+        self.mem = self.size * self.data.element_size()
 
     def append(self, tensor: torch.Tensor) -> None:
         assert tensor.dtype == self.dtype
@@ -167,6 +168,10 @@ class Chunk:
         self.data.copy_(dest_chunk.data)
         self._update_tensors_ptr()
 
+    @property
+    def device_type(self) -> str:
+        return self.data.device.type
+
 
 class ChunkManager:
 
@@ -184,6 +189,7 @@ class ChunkManager:
         self.lazy_release_tensors: List[torch.Tensor] = []
         if enable_distributed_storage and chunk_size is None:
             self.rank_load: Dict[str, torch.Tensor] = {}
+        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
     def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
         assert tensor not in self.tensor_chunk_map
@@ -202,6 +208,8 @@ class ChunkManager:
                 self.rank_load[group_name][src_rank] += chunk_size
             self.chunk_groups[group_name].append(chunk)
         chunk.append(tensor)
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] += chunk.mem
         self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
         if not self.enable_distributed_storage:
             self.accessed_chunks.add(self.chunk_groups[group_name][-1])
@@ -222,8 +230,11 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if chunk in self.accessed_chunks:
             return
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
         chunk.access()
         self.accessed_chunks.add(chunk)
+        self.total_mem[chunk.device_type] += chunk.mem
 
     def release_chunk(self, tensor: torch.Tensor) -> None:
         if not self.enable_distributed_storage:
@@ -234,11 +245,17 @@ class ChunkManager:
         if chunk.can_release:
             chunk.release()
             self.accessed_chunks.remove(chunk)
+            if chunk.is_free:
+                self.total_mem[chunk.device_type] -= chunk.mem
 
     def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
         chunk = self.tensor_chunk_map[tensor]
-        if chunk.can_move_device:
+        if chunk.data.device == device:
+            return
+        if chunk.can_move_device and not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
             chunk.move_device(device)
+            self.total_mem[chunk.device_type] += chunk.mem
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
@@ -248,7 +265,9 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if not chunk.can_reduce:
             return False
+        self.total_mem[chunk.device_type] -= chunk.mem
         chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+        self.total_mem[chunk.device_type] += chunk.mem
         return True
 
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
@@ -272,6 +291,7 @@ class ChunkManager:
 
     def __repr__(self) -> str:
         msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
             for i, chunk in enumerate(group):
diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py
index f367753de..515fb710c 100644
--- a/tests/test_tensor/test_chunk.py
+++ b/tests/test_tensor/test_chunk.py
@@ -32,6 +32,8 @@ HAS_TENSORS = {
     }
 }
 
+TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+
 
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
@@ -42,15 +44,27 @@ def run_chunk_zero(use_chunk, use_zero):
     params = [torch.rand(32, 32) for _ in range(3)]
     chunk_size = 2048 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == 0
     for p in params:
         chunk_manager.append_tensor(p, 'param')
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
     for p in params:
         chunk_manager.access_chunk(p)
     check_has_params(params, [True, True, True])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
     for p in params:
         chunk_manager.release_chunk(p)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
+    for p in params:
+        chunk_manager.move_chunk(p, torch.device('cpu'))
+    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
+    assert chunk_manager.total_mem['cuda'] == 0
 
 
 def run_dist(rank, world_size, port):