refactor the code structure to solve the circular import

pull/5733/head
genghaozhe 6 months ago
parent a280517dd9
commit bfcb2d1ff8

@@ -51,7 +51,12 @@ class GeminiZeROHook(ColoParamOpHook):
             self._chunk_manager.access_chunk(chunk)
         # get possible chunks to prefetch
-        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
+            is_warmup=self._gemini_manager.is_warmup(),
+            compute_list=self._gemini_manager.compute_list,
+            compute_idx=self._gemini_manager.compute_idx,
+            async_works=self._gemini_manager.async_works,
+        )
         # prefetch
         for chunk in chunks_fetch_async:
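
The call-site change above shows the whole strategy of the commit: the hook reads the manager's state and hands it to the policy as plain arguments, so the policy never needs a reference to, or an import of, the manager. A minimal self-contained sketch of that dependency direction, with illustrative names rather than the real ColossalAI types:

    from typing import Dict, List, Tuple


    class Policy:
        """Knows nothing about the manager; receives plain data instead."""

        def __init__(self, max_prefetch: int = 2) -> None:
            self.max_prefetch = max_prefetch

        def get_prefetch_chunks(
            self, is_warmup: bool, compute_list: Tuple, compute_idx: int, async_works: Dict
        ) -> List:
            if is_warmup:
                return []
            budget = self.max_prefetch - len(async_works)
            upcoming = [c for group in compute_list[compute_idx + 1 :] for c in group]
            return upcoming[: max(budget, 0)]


    class Manager:
        """Owns the runtime state and passes it to the policy on each call."""

        def __init__(self) -> None:
            self.policy = Policy()
            self.compute_list = (("chunk0",), ("chunk1", "chunk2"))
            self.compute_idx = 0
            self.async_works: Dict = {}

        def prefetch(self) -> List:
            return self.policy.get_prefetch_chunks(
                is_warmup=False,
                compute_list=self.compute_list,
                compute_idx=self.compute_idx,
                async_works=self.async_works,
            )


    print(Manager().prefetch())  # ['chunk1', 'chunk2']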

@@ -45,7 +45,7 @@ class GeminiManager:
         self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
-        self._async_works: Dict[Chunk, dist.work] = {}
+        self._async_works: Dict[Chunk, dist.Work] = {}

         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -183,6 +183,10 @@ class GeminiManager:
     def compute_idx(self) -> int:
         return self._compute_idx

+    @property
+    def async_works(self) -> Dict[Chunk, dist.Work]:
+        return self._async_works
+
     @property
     def placement_policy(self) -> PlacementPolicy:
         return self._placement_policy
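
The new async_works property gives callers read access to the in-flight prefetch communications keyed by chunk. A hypothetical usage sketch (wait_async_works is not part of this commit) of how a caller could drain those dist.Work handles:

    import torch.distributed as dist  # the dict values below are dist.Work handles


    def wait_async_works(async_works: dict) -> None:
        # Each value was returned by an async collective; wait() blocks
        # until that communication has finished, then the entry is dropped.
        for chunk, work in list(async_works.items()):
            work.wait()
            async_works.pop(chunk)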

@@ -5,13 +5,13 @@ from time import time
 from typing import Dict, List, Optional, Tuple, Type

 import torch
+import torch.distributed as dist

 from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk

 from .chunk import Chunk, ChunkManager
-from .gemini_mgr import GeminiManager
 from .memory_tracer import ChunkMemStatsCollector
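
Dropping the import of GeminiManager is what actually breaks the cycle, since the manager module already imports the placement policies (its placement_policy property is annotated with PlacementPolicy). If the name were still wanted purely for annotations, a common alternative, not what this commit does, is a TYPE_CHECKING-guarded import, sketched here:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # seen by static type checkers only, never executed at runtime
        from .gemini_mgr import GeminiManager


    class PlacementPolicy:
        def __init__(self, gemini_manager: "GeminiManager") -> None:
            self.gemini_manager = gemini_manager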
@@ -20,13 +20,11 @@ class PlacementPolicy(ABC):
     def __init__(
         self,
-        gemini_manager: "GeminiManager",  # TODO @botbw: solve circular import
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
         **kwargs,
     ) -> None:
-        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
         self.max_prefetch = max_prefetch
@@ -41,14 +39,15 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError

-    def get_prefetch_chunks(self) -> List[Chunk]:
+    def get_prefetch_chunks(
+        self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
         return []  # no prefetch by default


 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -57,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(
-            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-        )
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -110,13 +107,15 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

-    def get_prefetch_chunks(self) -> List[Chunk]:
-        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
             return []
-        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        can_prefetch = self.max_prefetch - len(async_works)
         prefetch = []
-        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-            for chunk in self.gemini_manager.compute_list[i]:
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
                 if len(prefetch) >= can_prefetch:
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
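
In isolation, the static policy's scan is a bounded look-ahead over the recorded schedule: start just past the current compute step and collect chunks that are neither already queued nor resident, capped by how many async-work slots remain. A toy re-run with plain data (illustrative only, not the real Chunk type):

    def static_prefetch(compute_list, compute_idx, accessed, max_prefetch, n_async):
        can_prefetch = max_prefetch - n_async  # free async-work slots
        prefetch = []
        for group in compute_list[compute_idx + 1 :]:
            for chunk in group:
                if len(prefetch) >= can_prefetch:
                    break
                if chunk not in prefetch and chunk not in accessed:
                    prefetch.append(chunk)
            else:
                continue  # inner loop exhausted, keep scanning
            break  # inner loop hit the cap
        return prefetch


    compute_list = [("A",), ("B", "C"), ("C", "D")]
    print(static_prefetch(compute_list, 0, accessed={"B"}, max_prefetch=2, n_async=0))
    # -> ['C', 'D']: "B" is already resident, "C" is deduplicated on the second pass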
@@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: GeminiManager,
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(
-            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-        )
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")

-    def get_prefetch_chunks(self) -> List[Chunk]:
-        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
             return []
         # modified from self.evict_tensors
         cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
@@ -246,14 +244,14 @@ class AutoPlacementPolicy(PlacementPolicy):
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

         prefetch_chunk_memory = 0
-        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        can_prefetch = self.max_prefetch - len(async_works)
         prefetch = []
-        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-            for chunk in self.gemini_manager.compute_list[i]:
-                chunk: Chunk
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
                 if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch_chunk_memory += chunk.chunk_mem
                     prefetch.append(chunk)
             else:
                 continue
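
Besides threading the state through as arguments, this last hunk fixes a latent bug: prefetch_chunk_memory was initialized and tested against avail_cuda_model_data but never incremented, so the memory cap could never bind. With the added accumulation the check behaves as intended; a toy numeric sketch (illustrative numbers):

    avail_cuda_model_data = 100  # CUDA headroom left for model data, in bytes
    chunk_mems = [40, 40, 40]    # candidate chunk sizes in compute order

    prefetch_chunk_memory = 0
    prefetch = []
    for mem in chunk_mems:
        if prefetch_chunk_memory + mem > avail_cuda_model_data:
            break
        prefetch_chunk_memory += mem  # the accumulation this commit adds
        prefetch.append(mem)

    print(prefetch)  # [40, 40]: a third 40-byte chunk would overshoot the budget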
