diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py
index 07ef74fbf..0885fb168 100644
--- a/colossalai/gemini/memory_tracer/memstats_collector.py
+++ b/colossalai/gemini/memory_tracer/memstats_collector.py
@@ -1,6 +1,7 @@
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.tensor import ChunkManager
 import torch
 import time
 
@@ -128,3 +129,19 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
+
+
+class MemStatsCollectorV2(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index e5b687248..63d275954 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -60,7 +60,7 @@ class ColoDDP(torch.nn.Module):
             else:
                 ColoDDP._save_grad(p, grad)
             return empty_grad
-    
+
         else:
             group = gpc.get_cpu_group(ParallelMode.DATA)
             dist.all_reduce(grad, group=group)
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
-            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data
@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
             grad = grad / self.dp_world_size
             self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
             chunk = self.chunk_manager.get_chunk(p)
-            reduced = self.chunk_manager.reduce_chunk(p)
-            self.chunk_manager.release_chunk(p)
+            reduced = self.chunk_manager.reduce_chunk(chunk)
+            self.chunk_manager.release_chunk(chunk)
             if reduced and not chunk.is_free:
                 self.overflow_counter += chunk.has_inf_or_nan
             return empty_grad
diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index b9f1ad466..86a51282b 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
 
 class ChunkManager:
 
@@ -226,8 +232,7 @@ class ChunkManager:
             src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
@@ -248,8 +252,7 @@ class ChunkManager:
         if chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
 
@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add an external static tensor to the chunk manager.
+        Such tensors are not managed by the chunk manager, but we still want to monitor their memory usage.
+        They are "static": their shape, dtype and device never change.
+        Thus, their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An external static tensor, e.g. an optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py
index 7b44ca1a6..737feedc4 100644
--- a/colossalai/zero/utils/zero_hook_v2.py
+++ b/colossalai/zero/utils/zero_hook_v2.py
@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._chunk_manager.exec_lazy_release()
         # TODO: evict chunks
-        for p in params:
-            self._chunk_manager.access_chunk(p)
+        for chunk in chunks:
+            self._chunk_manager.access_chunk(chunk)
 
     def post_op(self, params):
         for p in params:
diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py
index 8cc21515d..0f84b4d9f 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/zero/zero_optimizer.py
@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
+                if not self.module.chunk_manager.get_chunk(p).is_free:
                     p.data = self.fp16_param_to_fp32_param[p]
                 else:
                     assert p.grad is None
diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py
index 515fb710c..f1d508d83 100644
--- a/tests/test_tensor/test_chunk.py
+++ b/tests/test_tensor/test_chunk.py
@@ -32,7 +32,7 @@ HAS_TENSORS = {
     }
 }
 
-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
 
 
 @parameterize('use_chunk', [False, True])
@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
     rank = gpc.get_local_rank(ParallelMode.DATA)
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(32, 32) for _ in range(3)]
-    chunk_size = 2048 if use_chunk else None
+    params = [torch.rand(8, 8) for _ in range(3)]
+    chunk_size = 128 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    for p in params:
-        chunk_manager.access_chunk(p)
+    chunks = chunk_manager.get_chunks(params)
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
     check_has_params(params, [True, True, True])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for p in params:
-        chunk_manager.release_chunk(p)
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for p in params:
-        chunk_manager.move_chunk(p, torch.device('cpu'))
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
     assert chunk_manager.total_mem['cuda'] == 0
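The `MemStatsCollectorV2` added above swaps the parent class's per-tensor `StatefulTensor` accounting for a single read of the chunk manager's running `total_mem` counters. A toy illustration of that sampling pattern, using stand-in classes rather than the library code:

```python
from typing import Dict, List


class ToyChunkManager:
    """Stand-in: only the total_mem counters that the collector reads."""

    def __init__(self) -> None:
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}


class ToyCollector:
    """Stand-in for MemStatsCollectorV2's sample_model_data."""

    def __init__(self, chunk_manager: ToyChunkManager) -> None:
        self._chunk_manager = chunk_manager
        self._start_flag = True    # sampling enabled
        self._model_data_cuda: List[int] = []
        self._model_data_cpu: List[int] = []

    def sample_model_data(self) -> None:
        # One sample = the chunk manager's current per-device totals.
        if self._start_flag:
            self._model_data_cuda.append(self._chunk_manager.total_mem['cuda'])
            self._model_data_cpu.append(self._chunk_manager.total_mem['cpu'])


mgr = ToyChunkManager()
mgr.total_mem['cuda'] = 512        # e.g. one resident 512-byte chunk
collector = ToyCollector(mgr)
collector.sample_model_data()
assert collector._model_data_cuda == [512]
```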
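The `__hash__`/`__eq__` pair added to `Chunk` and the new `ChunkManager.get_chunks` work together: membership tests such as `chunk in self.accessed_chunks` and the dedup inside `get_chunks` compare chunks by object identity, and because several tensors routinely live in one chunk, the old tensor-level loops could visit (and reduce or release) the same chunk more than once. A minimal, self-contained sketch of both behaviours, with toy stand-ins rather than the library classes:

```python
from typing import Dict, Iterable, List, Tuple


class ToyChunk:
    # Identity semantics, as in the diff: two chunks are "equal" only
    # when they are the very same object.
    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other


def get_chunks(tensor_chunk_map: Dict[str, ToyChunk],
               names: Iterable[str]) -> Tuple[ToyChunk, ...]:
    # Same order-preserving dedup as ChunkManager.get_chunks: tensors
    # that live in the same chunk yield that chunk exactly once.
    chunks: List[ToyChunk] = []
    for name in names:
        chunk = tensor_chunk_map[name]
        if chunk not in chunks:
            chunks.append(chunk)
    return tuple(chunks)


shared = ToyChunk()
mapping = {'w1': shared, 'w2': shared, 'w3': ToyChunk()}
# 'w1' and 'w2' share a chunk, so only two distinct chunks come back.
assert len(get_chunks(mapping, ['w1', 'w2', 'w3'])) == 2
```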
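`add_extern_static_tensor` does no chunking at all; it only charges the tensor's footprint to `total_mem` once, which is safe precisely because the tensor's shape, dtype and device never change. The accounting is a one-liner, shown here with PyTorch and an arbitrary fp32 shape chosen for illustration:

```python
import torch

# e.g. an optimizer-state buffer kept outside the chunk system
state = torch.zeros(8, 8, dtype=torch.float32)
nbytes = state.numel() * state.element_size()    # 64 elements * 4 bytes
assert nbytes == 256
```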
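The shrunken constants in `test_chunk.py` follow directly from the new shapes. Assuming fp32 tensors, a chunk size counted in elements, and a two-rank DATA group (all inferred from the test, not stated in the diff), the arithmetic behind `TOTAL_MEM` checks out:

```python
ELEM = 4                                # bytes per fp32 element
param = 8 * 8 * ELEM                    # 256 bytes per torch.rand(8, 8)

# use_chunk=False: tensors are stored individually.
assert 3 * param == 768                          # use_zero=False: each rank holds all 3
assert [2 * param, 1 * param] == [512, 256]      # use_zero=True: tensors split 2 / 1 across ranks

# use_chunk=True: 3 * 64 elements pack into two 128-element chunks.
chunk = 128 * ELEM                      # 512 bytes per chunk
assert 2 * chunk == 1024                # use_zero=False: both chunks on every rank
assert 1 * chunk == 512                 # use_zero=True: one chunk per rank
```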