From 3d625ca83656793d81eaece4c49967bd4fafcf7d Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Fri, 17 May 2024 10:55:28 +0000 Subject: [PATCH 1/9] add some todo Message --- colossalai/zero/gemini/chunk/manager.py | 2 +- colossalai/zero/gemini/gemini_hook.py | 5 ++++- colossalai/zero/gemini/gemini_mgr.py | 3 ++- colossalai/zero/gemini/placement_policy.py | 12 ++++++++---- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index c7bdd5e1f..341790a72 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -83,7 +83,7 @@ class ChunkManager: if chunk_group: # the chunk group is not empty # close the last chunk - self.__close_one_chunk(chunk_group[-1]) + self.__close_one_chunk(chunk_group[-1]) # chunk[-1] 满了,所以关闭,不能再添加,然后同时scatter到ZeRO PG中 if tensor.numel() > chunk_size: chunk_size = tensor.numel() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 27a19c132..bf990d127 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -33,19 +33,22 @@ class GeminiZeROHook(ColoParamOpHook): all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) + chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # 当前要fetch的chunk # transfer state for p in params: + # TODO(haze188): check状态转换 self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched + # TODO(haze188): 可能我们prefetch的又被淘汰掉, check一下 self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) # fetch the rest synchronously + # TODO(haze188): 1. 
先prefetch还是先fetch(prefetch是异步,fetch是同步) for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 11bde789c..2e96c22f3 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -125,7 +125,7 @@ class GeminiManager: self._async_works[chunk].wait() del self._async_works[chunk] else: - non_prefetched_chunks.append(chunk) + non_prefetched_chunks.append(chunk) # 没在之前prefetch过,现在要prefetch的chunk return tuple(non_prefetched_chunks) def add_work(self, chunk: Chunk, work: dist.Work): @@ -154,6 +154,7 @@ class GeminiManager: def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 + # TODO(haze188): _compute_list 记录块的访问顺序 if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c0f92fa50..4c3d8dbe2 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -45,9 +45,9 @@ class PlacementPolicy(ABC): raise NotImplementedError -import os - -rank = int(os.environ["RANK"]) +# import torch.distributed as dist +# # rank = int(os.environ["RANK"]) +# rank = dist.get_rank() class StaticPlacementPolicy(PlacementPolicy): @@ -118,8 +118,10 @@ class StaticPlacementPolicy(PlacementPolicy): def get_prefetch_chunks(self) -> List[Chunk]: if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] + # 最多有多少个异步的work can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] + # static炸就炸了,dynamic可能需要我们要先分析当前运行时的内存情况,分配空间或者淘汰块 for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): for chunk in self.gemini_manager.compute_list[i]: if len(prefetch) >= can_prefetch: @@ -238,7 +240,9 @@ class AutoPlacementPolicy(PlacementPolicy): grads_device_map[p] = torch.device("cpu") def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: - return [] # TODO @botbw: implement prefetching for auto + # TODO @haze188 @botbw: implement prefetching for auto + + return [] class PlacementPolicyFactory: From 06a3a100b330d10b28615af285642c6667ba8c23 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Fri, 17 May 2024 10:57:49 +0000 Subject: [PATCH 2/9] remove unrelated code --- colossalai/zero/gemini/placement_policy.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 4c3d8dbe2..c0d03ba3b 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -45,11 +45,6 @@ class PlacementPolicy(ABC): raise NotImplementedError -# import torch.distributed as dist -# # rank = int(os.environ["RANK"]) -# rank = dist.get_rank() - - class StaticPlacementPolicy(PlacementPolicy): def __init__( self, From d22bf30ca645aac20265373fe34db281db6abb2e Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 04:01:53 +0000 Subject: [PATCH 3/9] implement auto policy prefetch and modify a little origin code. 
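Both policies now use Python's for/else so that hitting the prefetch budget stops the
whole scan of compute_list, rather than only breaking the inner loop and continuing to
iterate over the remaining steps. A toy illustration of that control flow (plain Python,
nothing Gemini-specific; the condition below just stands in for "budget exhausted"):

    for i in range(3):
        for j in range(3):
            print(i, j)
            if i == 1 and j == 2:   # pretend the budget is exhausted here
                break               # leaves only the inner loop
        else:
            continue                # runs when the inner loop finished without break
        break                       # reached only if the inner loop broke -> stop the outer loop

This prints the pairs up to (1, 2) and then stops both loops.
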
--- colossalai/zero/gemini/placement_policy.py | 36 ++++++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c0d03ba3b..9803d7f6d 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -11,6 +11,7 @@ from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager +from .gemini_mgr import GeminiManager from .memory_tracer import ChunkMemStatsCollector @@ -123,8 +124,9 @@ class StaticPlacementPolicy(PlacementPolicy): break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= can_prefetch: - break + else: + continue + break return prefetch @@ -133,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: "GeminiManager", + gemini_manager: GeminiManager, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -234,10 +236,32 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: - # TODO @haze188 @botbw: implement prefetching for auto + def get_prefetch_chunks(self) -> List[Chunk]: + if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + return [] + # modified from self.evict_tensors + cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( + get_accelerator().get_current_device() + ) + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") + used_cuda_model_data = self.chunk_manager.total_mem["cuda"] + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data - return [] + prefetch_chunk_memory = 0 + can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + prefetch = [] + for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): + for chunk in self.gemini_manager.compute_list[i]: + chunk: Chunk + if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data: + break + if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: + prefetch.append(chunk) + else: + continue + break + return prefetch class PlacementPolicyFactory: From df63db7e63d951017fd1fa797fbcaec259fba644 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:15:51 +0000 Subject: [PATCH 4/9] remote comments --- colossalai/zero/gemini/gemini_hook.py | 4 +- colossalai/zero/gemini/gemini_mgr.py | 1 - examples/language/gpt/gemini/demo.ipynb | 142 ++++++++++++++++++ examples/language/gpt/gemini/run_gemini.sh | 2 +- .../language/gpt/gemini/train_gpt_demo.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 6 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index bf990d127..e691b423b 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -37,18 +37,16 @@ class GeminiZeROHook(ColoParamOpHook): # transfer state for p in params: - # TODO(haze188): check状态转换 
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - # TODO(haze188): 可能我们prefetch的又被淘汰掉, check一下 + # TODO: check if prefetched chunks will be evicted self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) # fetch the rest synchronously - # TODO(haze188): 1. 先prefetch还是先fetch(prefetch是异步,fetch是同步) for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 2e96c22f3..85beafd32 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -154,7 +154,6 @@ class GeminiManager: def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - # TODO(haze188): _compute_list 记录块的访问顺序 if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb new file mode 100644 index 000000000..09953b3a9 --- /dev/null +++ b/examples/language/gpt/gemini/demo.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(in_features=10, out_features=5, bias=False) 50\n", + "Linear(in_features=5, out_features=10, bias=False) 50\n", + "Linear(in_features=10, out_features=10, bias=False) 100\n" + ] + } + ], + "source": [ + "class Toy(nn.Module):\n", + " \n", + " def __init__(self):\n", + " super(Toy, self).__init__()\n", + " self.fc1 = nn.Linear(10,5, bias=False)\n", + " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", + "\n", + "t = Toy()\n", + "for mod in t.modules():\n", + " for p in mod.parameters(recurse=False):\n", + " print(mod, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 10]) 50\n", + "torch.Size([10, 5]) 50\n", + "torch.Size([10, 10]) 100\n" + ] + } + ], + "source": [ + "for p in t.parameters():\n", + " print(p.shape, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'224'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "conf_str = torch.__config__.parallel_info()\n", + "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", + "max_concurrency = inter_str.split(\"\\n\")[0]\n", + "max_concurrency" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0\n", + "0 1\n", + "0 2\n", + "1 0\n", + "1 1\n", + "1 2\n" + ] + } + ], + "source": [ + "for i in range(3):\n", + " for j in range(3):\n", + " print(i, j)\n", + " if i == 1 and j == 2:break\n", + " else:\n", + " continue\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "colossalai-py310", + 
"language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index 5eaa4af4d..bffd26f59 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} export GPUNUM=${GPUNUM:-1} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} -export TRAIN_STEP=${TRAIN_STEP:-10} +export TRAIN_STEP=${TRAIN_STEP:-2} # export PYTHONPATH=$PWD:$PYTHONPATH diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 6db74231a..667a0c77a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 + return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 + return torch.cuda.memory_allocated() / 1024**2 # 转换成MB def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): +def get_model_size(model: nn.Module): # 得到模型参数量 total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 1c914ca0e..4e1fb988b 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - {"placement_policy": "auto"}, + # {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From 5c6c5d6be316a4f4e867d0d8049b508e0d59ad6c Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:15:51 +0000 Subject: [PATCH 5/9] remove comments --- colossalai/zero/gemini/gemini_hook.py | 4 +- colossalai/zero/gemini/gemini_mgr.py | 1 - examples/language/gpt/gemini/demo.ipynb | 142 ++++++++++++++++++ examples/language/gpt/gemini/run_gemini.sh | 2 +- .../language/gpt/gemini/train_gpt_demo.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 6 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index bf990d127..e691b423b 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -37,18 +37,16 @@ class GeminiZeROHook(ColoParamOpHook): # transfer state for p in params: - # TODO(haze188): check状态转换 self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - # TODO(haze188): 可能我们prefetch的又被淘汰掉, check一下 + # TODO: check if prefetched chunks will be evicted self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) # fetch the rest synchronously - # TODO(haze188): 1. 
先prefetch还是先fetch(prefetch是异步,fetch是同步) for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 2e96c22f3..85beafd32 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -154,7 +154,6 @@ class GeminiManager: def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - # TODO(haze188): _compute_list 记录块的访问顺序 if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb new file mode 100644 index 000000000..09953b3a9 --- /dev/null +++ b/examples/language/gpt/gemini/demo.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(in_features=10, out_features=5, bias=False) 50\n", + "Linear(in_features=5, out_features=10, bias=False) 50\n", + "Linear(in_features=10, out_features=10, bias=False) 100\n" + ] + } + ], + "source": [ + "class Toy(nn.Module):\n", + " \n", + " def __init__(self):\n", + " super(Toy, self).__init__()\n", + " self.fc1 = nn.Linear(10,5, bias=False)\n", + " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", + "\n", + "t = Toy()\n", + "for mod in t.modules():\n", + " for p in mod.parameters(recurse=False):\n", + " print(mod, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 10]) 50\n", + "torch.Size([10, 5]) 50\n", + "torch.Size([10, 10]) 100\n" + ] + } + ], + "source": [ + "for p in t.parameters():\n", + " print(p.shape, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'224'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "conf_str = torch.__config__.parallel_info()\n", + "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", + "max_concurrency = inter_str.split(\"\\n\")[0]\n", + "max_concurrency" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0\n", + "0 1\n", + "0 2\n", + "1 0\n", + "1 1\n", + "1 2\n" + ] + } + ], + "source": [ + "for i in range(3):\n", + " for j in range(3):\n", + " print(i, j)\n", + " if i == 1 and j == 2:break\n", + " else:\n", + " continue\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "colossalai-py310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/language/gpt/gemini/run_gemini.sh 
b/examples/language/gpt/gemini/run_gemini.sh index 5eaa4af4d..bffd26f59 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} export GPUNUM=${GPUNUM:-1} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} -export TRAIN_STEP=${TRAIN_STEP:-10} +export TRAIN_STEP=${TRAIN_STEP:-2} # export PYTHONPATH=$PWD:$PYTHONPATH diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 6db74231a..667a0c77a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 + return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 + return torch.cuda.memory_allocated() / 1024**2 # 转换成MB def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): +def get_model_size(model: nn.Module): # 得到模型参数量 total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 1c914ca0e..4e1fb988b 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - {"placement_policy": "auto"}, + # {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From 1ec92d29af16fcfc1b641e61eded877c5680cc47 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:21:26 +0000 Subject: [PATCH 6/9] remove perf log, unrelated file and so on --- colossalai/zero/gemini/chunk/manager.py | 2 +- colossalai/zero/gemini/gemini_hook.py | 2 +- colossalai/zero/gemini/gemini_mgr.py | 2 +- colossalai/zero/gemini/placement_policy.py | 2 - examples/language/gpt/gemini/demo.ipynb | 142 ------------------ .../language/gpt/gemini/train_gpt_demo.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 7 files changed, 7 insertions(+), 151 deletions(-) delete mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 341790a72..c7bdd5e1f 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -83,7 +83,7 @@ class ChunkManager: if chunk_group: # the chunk group is not empty # close the last chunk - self.__close_one_chunk(chunk_group[-1]) # chunk[-1] 满了,所以关闭,不能再添加,然后同时scatter到ZeRO PG中 + self.__close_one_chunk(chunk_group[-1]) if tensor.numel() > chunk_size: chunk_size = tensor.numel() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index e691b423b..450cb3ad6 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -33,7 +33,7 @@ class GeminiZeROHook(ColoParamOpHook): all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # 当前要fetch的chunk + chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # transfer state for p in params: diff 
--git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 85beafd32..11bde789c 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -125,7 +125,7 @@ class GeminiManager: self._async_works[chunk].wait() del self._async_works[chunk] else: - non_prefetched_chunks.append(chunk) # 没在之前prefetch过,现在要prefetch的chunk + non_prefetched_chunks.append(chunk) return tuple(non_prefetched_chunks) def add_work(self, chunk: Chunk, work: dist.Work): diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 9e9fb1f58..cfbf16d1b 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -113,10 +113,8 @@ class StaticPlacementPolicy(PlacementPolicy): def get_prefetch_chunks(self) -> List[Chunk]: if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] - # 最多有多少个异步的work can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] - # static炸就炸了,dynamic可能需要我们要先分析当前运行时的内存情况,分配空间或者淘汰块 for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): for chunk in self.gemini_manager.compute_list[i]: if len(prefetch) >= can_prefetch: diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb deleted file mode 100644 index 09953b3a9..000000000 --- a/examples/language/gpt/gemini/demo.ipynb +++ /dev/null @@ -1,142 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(in_features=10, out_features=5, bias=False) 50\n", - "Linear(in_features=5, out_features=10, bias=False) 50\n", - "Linear(in_features=10, out_features=10, bias=False) 100\n" - ] - } - ], - "source": [ - "class Toy(nn.Module):\n", - " \n", - " def __init__(self):\n", - " super(Toy, self).__init__()\n", - " self.fc1 = nn.Linear(10,5, bias=False)\n", - " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", - "\n", - "t = Toy()\n", - "for mod in t.modules():\n", - " for p in mod.parameters(recurse=False):\n", - " print(mod, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 10]) 50\n", - "torch.Size([10, 5]) 50\n", - "torch.Size([10, 10]) 100\n" - ] - } - ], - "source": [ - "for p in t.parameters():\n", - " print(p.shape, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'224'" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "conf_str = torch.__config__.parallel_info()\n", - "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", - "max_concurrency = inter_str.split(\"\\n\")[0]\n", - "max_concurrency" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0\n", - "0 1\n", - "0 2\n", - "1 0\n", - "1 1\n", - "1 2\n" - ] - } - ], - "source": [ - "for i in range(3):\n", - " for j in range(3):\n", - " print(i, j)\n", - " if i == 1 and j == 2:break\n", - " 
else:\n", - " continue\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "colossalai-py310", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 667a0c77a..6db74231a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB + return psutil.Process().memory_info().rss / 1024**2 def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 # 转换成MB + return torch.cuda.memory_allocated() / 1024**2 def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): # 得到模型参数量 +def get_model_size(model: nn.Module): total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 4e1fb988b..1c914ca0e 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - # {"placement_policy": "auto"}, + {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From a280517dd9618247deaea729b4f1aaddbc17995c Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:25:35 +0000 Subject: [PATCH 7/9] remove unrelated file --- examples/language/gpt/gemini/demo.ipynb | 142 ------------------ .../language/gpt/gemini/train_gpt_demo.py | 7 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 3 files changed, 5 insertions(+), 146 deletions(-) delete mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb deleted file mode 100644 index 09953b3a9..000000000 --- a/examples/language/gpt/gemini/demo.ipynb +++ /dev/null @@ -1,142 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(in_features=10, out_features=5, bias=False) 50\n", - "Linear(in_features=5, out_features=10, bias=False) 50\n", - "Linear(in_features=10, out_features=10, bias=False) 100\n" - ] - } - ], - "source": [ - "class Toy(nn.Module):\n", - " \n", - " def __init__(self):\n", - " super(Toy, self).__init__()\n", - " self.fc1 = nn.Linear(10,5, bias=False)\n", - " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", - "\n", - "t = Toy()\n", - "for mod in t.modules():\n", - " for p in mod.parameters(recurse=False):\n", - " print(mod, p.numel())" - ] - }, - { - "cell_type": "code", - 
"execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 10]) 50\n", - "torch.Size([10, 5]) 50\n", - "torch.Size([10, 10]) 100\n" - ] - } - ], - "source": [ - "for p in t.parameters():\n", - " print(p.shape, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'224'" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "conf_str = torch.__config__.parallel_info()\n", - "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", - "max_concurrency = inter_str.split(\"\\n\")[0]\n", - "max_concurrency" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0\n", - "0 1\n", - "0 2\n", - "1 0\n", - "1 1\n", - "1 2\n" - ] - } - ], - "source": [ - "for i in range(3):\n", - " for j in range(3):\n", - " print(i, j)\n", - " if i == 1 and j == 2:break\n", - " else:\n", - " continue\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "colossalai-py310", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 667a0c77a..bf1be87ba 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,19 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB + return psutil.Process().memory_info().rss / 1024**2 # MB unit def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 # 转换成MB + return torch.cuda.memory_allocated() / 1024**2 # MB unit def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): # 得到模型参数量 +def get_model_size(model: nn.Module): + # get the number of parameter of the model total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 4e1fb988b..1c914ca0e 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - # {"placement_policy": "auto"}, + {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From bfcb2d1ff8dee52746f9d7af76ffe3acf0312ea5 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 07:25:24 +0000 Subject: [PATCH 8/9] refactor the code structure to solve the circular import --- colossalai/zero/gemini/gemini_hook.py | 7 +++- colossalai/zero/gemini/gemini_mgr.py | 6 ++- colossalai/zero/gemini/placement_policy.py | 44 +++++++++++----------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git 
a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 450cb3ad6..315730f7a 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -51,7 +51,12 @@ class GeminiZeROHook(ColoParamOpHook): self._chunk_manager.access_chunk(chunk) # get possible chunks to prefetch - chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks( + is_warmup=self._gemini_manager.is_warmup(), + compute_list=self._gemini_manager.compute_list, + compute_idx=self._gemini_manager.compute_idx, + async_works=self._gemini_manager.async_works, + ) # prefetch for chunk in chunks_fetch_async: diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 11bde789c..5b309c7a1 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -45,7 +45,7 @@ class GeminiManager: self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 - self._async_works: Dict[Chunk, dist.work] = {} + self._async_works: Dict[Chunk, dist.Work] = {} self._h2d_volume = 0 self._d2h_volume = 0 @@ -183,6 +183,10 @@ class GeminiManager: def compute_idx(self) -> int: return self._compute_idx + @property + def async_works(self) -> Dict[Chunk, dist.Work]: + return self._async_works + @property def placement_policy(self) -> PlacementPolicy: return self._placement_policy diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index cfbf16d1b..9b1d1a6ab 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -5,13 +5,13 @@ from time import time from typing import Dict, List, Optional, Tuple, Type import torch +import torch.distributed as dist from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager -from .gemini_mgr import GeminiManager from .memory_tracer import ChunkMemStatsCollector @@ -20,13 +20,11 @@ class PlacementPolicy(ABC): def __init__( self, - gemini_manager: "GeminiManager", # TODO @botbw: solve circular import chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, **kwargs, ) -> None: - self.gemini_manager = gemini_manager self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector self.max_prefetch = max_prefetch @@ -41,14 +39,15 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError - def get_prefetch_chunks(self) -> List[Chunk]: + def get_prefetch_chunks( + self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] + ) -> List[Chunk]: return [] # no prefetch by default class StaticPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: "GeminiManager", chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -57,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy): offload_param_frac: float = 0.0, **kwargs, ) -> None: - super().__init__( - gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch - ) + super().__init__(chunk_manager, 
mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 @@ -110,13 +107,15 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) - def get_prefetch_chunks(self) -> List[Chunk]: - if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + def get_prefetch_chunks( + self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] + ) -> List[Chunk]: + if is_warmup: # no prefetch during warmup since we need compute_list return [] - can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + can_prefetch = self.max_prefetch - len(async_works) prefetch = [] - for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): - for chunk in self.gemini_manager.compute_list[i]: + for i in range(compute_idx + 1, len(compute_list)): + for chunk in compute_list[i]: if len(prefetch) >= can_prefetch: break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: @@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: GeminiManager, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy): steady_cuda_cap_ratio: float = 0.9, **kwargs, ) -> None: - super().__init__( - gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch - ) + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() @@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") - def get_prefetch_chunks(self) -> List[Chunk]: - if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + def get_prefetch_chunks( + self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] + ) -> List[Chunk]: + if is_warmup: # no prefetch during warmup since we need compute_list return [] # modified from self.evict_tensors cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( @@ -246,14 +244,14 @@ class AutoPlacementPolicy(PlacementPolicy): avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data prefetch_chunk_memory = 0 - can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + can_prefetch = self.max_prefetch - len(async_works) prefetch = [] - for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): - for chunk in self.gemini_manager.compute_list[i]: - chunk: Chunk + for i in range(compute_idx + 1, len(compute_list)): + for chunk in compute_list[i]: if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data: break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: + prefetch_chunk_memory += chunk.chunk_mem prefetch.append(chunk) 
else: continue From 90d8d0183c39832cc2a5951d4d4437a69e878a18 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 07:28:20 +0000 Subject: [PATCH 9/9] remove personal comments --- colossalai/zero/gemini/gemini_hook.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 315730f7a..cab26c822 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -41,7 +41,6 @@ class GeminiZeROHook(ColoParamOpHook): self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - # TODO: check if prefetched chunks will be evicted self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 )
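
Note on the final prefetch path: after PATCH 8 both placement policies share the signature
get_prefetch_chunks(is_warmup, compute_list, compute_idx, async_works), and the hook gathers
those values from GeminiManager, so placement_policy.py no longer imports gemini_mgr and the
circular import disappears. For the auto policy the selection is bounded twice: by the free
async slots (max_prefetch minus the number of in-flight works) and by the CUDA memory still
available for model data (steady_cuda_cap_ratio times device capacity, minus the predicted
non-model data for the next period, minus chunk memory already on CUDA). The standalone
sketch below restates that rule so the two budgets are easy to see; FakeChunk,
pick_prefetch_chunks and all numbers are made up for illustration and are not part of the
real API.

    from dataclasses import dataclass
    from typing import List, Sequence, Set

    @dataclass(frozen=True)
    class FakeChunk:                 # stand-in for a gemini Chunk; only chunk_mem matters here
        name: str
        chunk_mem: int               # bytes the chunk occupies once gathered on CUDA

    def pick_prefetch_chunks(compute_list: Sequence[Sequence[FakeChunk]],
                             compute_idx: int,
                             accessed: Set[FakeChunk],
                             num_async_in_flight: int,
                             max_prefetch: int,
                             avail_cuda_model_data: int) -> List[FakeChunk]:
        can_prefetch = max_prefetch - num_async_in_flight     # free async slots
        prefetch: List[FakeChunk] = []
        prefetch_mem = 0
        for step in compute_list[compute_idx + 1:]:           # walk the future compute order
            for chunk in step:
                if len(prefetch) >= can_prefetch or prefetch_mem + chunk.chunk_mem > avail_cuda_model_data:
                    break                                     # either budget exhausted -> leave inner loop
                if chunk not in prefetch and chunk not in accessed:
                    prefetch_mem += chunk.chunk_mem
                    prefetch.append(chunk)
            else:
                continue                                      # inner loop finished without break
            break                                             # inner loop broke -> stop scanning
        return prefetch

    # toy numbers: 2 free slots and ~300 MB of headroom -> only the first two chunks qualify
    a, b, c = FakeChunk("a", 100 << 20), FakeChunk("b", 150 << 20), FakeChunk("c", 200 << 20)
    print(pick_prefetch_chunks([(a, b), (c,)], -1, set(), 0, 2, 300 << 20))   # -> [a, b]

The static policy performs the same walk without the memory term, and during warm-up both
policies return an empty list because compute_list has not been recorded yet.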