[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/5722/head
pre-commit-ci[bot] 2024-05-16 07:26:19 +00:00
parent 82b25524ff
commit 6bbe956316
4 changed files with 32 additions and 17 deletions

@@ -131,7 +131,7 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
-            max_prefetch=max_prefetch
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)

@ -1,10 +1,9 @@
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Dict, List, Iterable, Tuple from typing import List
import torch import torch
import torch.distributed as dist
from colossalai.logging import DistributedLogger from colossalai.logging import DistributedLogger
from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.tensor.param_op_hook import ColoParamOpHook
@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.gemini_mgr import GeminiManager from colossalai.zero.gemini.gemini_mgr import GeminiManager
from .chunk import Chunk
class TrainingPhase(Enum): class TrainingPhase(Enum):
FORWARD = 0 FORWARD = 0
@ -23,7 +20,9 @@ class TrainingPhase(Enum):
logger = DistributedLogger("gemini_hook") logger = DistributedLogger("gemini_hook")
import os import os
rank = int(os.environ['RANK'])
rank = int(os.environ["RANK"])
class GeminiZeROHook(ColoParamOpHook): class GeminiZeROHook(ColoParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None: def __init__(self, gemini_manager: GeminiManager) -> None:
@ -32,14 +31,13 @@ class GeminiZeROHook(ColoParamOpHook):
self._chunk_manager = gemini_manager.chunk_manager self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params): def pre_op(self, params):
# map params to chunks # map params to chunks
params = [p for p in params if not is_ddp_ignored(p)] params = [p for p in params if not is_ddp_ignored(p)]
all_chunks = self._chunk_manager.get_chunks(params) all_chunks = self._chunk_manager.get_chunks(params)
# wait for prefetched chunks, filter those are not prefetched # wait for prefetched chunks, filter those are not prefetched
unique_chunks = set(all_chunks) set(all_chunks)
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
# transfer state # transfer state
@ -48,7 +46,9 @@ class GeminiZeROHook(ColoParamOpHook):
self._gemini_manager.sample_overall_data() self._gemini_manager.sample_overall_data()
# evit chunks, aware of async fetched # evit chunks, aware of async fetched
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0) self._gemini_manager.adjust_layout(
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
)
# fetch the rest synchronously # fetch the rest synchronously
for chunk in chunks_fetch_sync: for chunk in chunks_fetch_sync:
@ -57,7 +57,9 @@ class GeminiZeROHook(ColoParamOpHook):
# get possible chunks to prefetch # get possible chunks to prefetch
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
if rank == 0 and not self._gemini_manager.is_warmup(): if rank == 0 and not self._gemini_manager.is_warmup():
print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}") print(
f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}"
)
print(f"{all_chunks=}") print(f"{all_chunks=}")
print(f"accessed_chunks={self._chunk_manager.accessed_chunks}") print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
print(f"{chunks_fetch_sync=}") print(f"{chunks_fetch_sync=}")

@ -1,6 +1,6 @@
import functools import functools
from time import time from time import time
from typing import Dict, List, Optional, Tuple, Iterable from typing import Dict, Iterable, List, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -101,7 +101,7 @@ class GeminiManager:
start = time() start = time()
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks) cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
# don't evict chunks that are asynchronously fetched # don't evict chunks that are asynchronously fetched
can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works] can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
self._layout_time += time() - start self._layout_time += time() - start
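For reference, the context above keeps asynchronously fetched chunks out of the eviction candidates by filtering against the in-flight async works. A tiny, hypothetical illustration of that filter, with strings standing in for Chunk objects and work handles:

# Hypothetical illustration of filtering eviction candidates against in-flight
# async fetches; names here are stand-ins, not ColossalAI types.
def filter_evictable(can_evict_chunks, async_works):
    # Keep only chunks that have no outstanding asynchronous fetch.
    return [chunk for chunk in can_evict_chunks if chunk not in async_works]


if __name__ == "__main__":
    async_works = {"chunk_b": "pending_work_handle"}
    print(filter_evictable(["chunk_a", "chunk_b", "chunk_c"], async_works))
    # ['chunk_a', 'chunk_c']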

@@ -13,11 +13,17 @@ from colossalai.zero.gemini.chunk import Chunk
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector

+
 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

     def __init__(
-        self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
         self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
@@ -38,13 +44,16 @@ class PlacementPolicy(ABC):
     def get_prefetch_chunks(self) -> List[Chunk]:
         raise NotImplementedError

+
 import os
+
 rank = int(os.environ["RANK"])
+

 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -53,7 +62,9 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -124,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy):

     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -132,7 +143,9 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
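The reformatted super().__init__ calls in this file only change layout: each concrete policy still forwards the chunk manager, the optional stats collector, and max_prefetch to the base class. A minimal sketch of that constructor plumbing, using hypothetical stand-in classes rather than the ColossalAI ones:

# Minimal sketch of the constructor plumbing shown above; _BasePolicy and
# _StaticPolicy are hypothetical stand-ins, not the ColossalAI classes.
from abc import ABC
from typing import Any, Optional


class _BasePolicy(ABC):
    def __init__(
        self,
        chunk_manager: Any,
        mem_stats_collector: Optional[Any] = None,
        max_prefetch: int = 0,
        **kwargs: Any,
    ) -> None:
        self.chunk_manager = chunk_manager
        self.mem_stats_collector = mem_stats_collector
        self.max_prefetch = max_prefetch


class _StaticPolicy(_BasePolicy):
    def __init__(
        self,
        chunk_manager: Any,
        mem_stats_collector: Optional[Any] = None,
        max_prefetch: int = 0,
        shard_param_frac: float = 1.0,
        **kwargs: Any,
    ) -> None:
        # Forward the shared arguments to the base class, keep the policy-specific ones here.
        super().__init__(
            chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
        )
        self.shard_param_frac = shard_param_frac


if __name__ == "__main__":
    policy = _StaticPolicy(chunk_manager=object(), max_prefetch=2)
    print(policy.max_prefetch)  # 2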