import functools
import warnings
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type

import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity

from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector


class PlacementPolicy(ABC):
    need_mem_stats: bool = False

    def __init__(
        self,
        chunk_manager: ChunkManager,
        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
        max_prefetch: int = 0,
        **kwargs,
    ) -> None:
        self.chunk_manager = chunk_manager
        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
        self.max_prefetch = max_prefetch

    @abstractmethod
    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        raise NotImplementedError

    @abstractmethod
    def setup_grads_device(
        self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
    ) -> None:
        raise NotImplementedError

    def get_prefetch_chunks(
        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
    ) -> List[Chunk]:
        return []  # no prefetch by default


class StaticPlacementPolicy(PlacementPolicy):
    def __init__(
        self,
        chunk_manager: ChunkManager,
        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
        max_prefetch: int = 0,
        shard_param_frac: float = 1.0,
        offload_optim_frac: float = 0.0,
        offload_param_frac: float = 0.0,
        **kwargs,
    ) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
        if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
            warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
            offload_param_frac = 0.0
        self.shard_param_frac = shard_param_frac
        self.offload_optim_frac = offload_optim_frac
        self.offload_param_frac = offload_param_frac
        # these are initialized in setup_grads_device
        self.keep_gathered_chunk_mem = 0.0
        self.keep_cuda_chunk_mem = 0.0

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
        can_offload_chunk_mem = can_shard_chunk_mem
        for chunk in can_evict_chunks:
            if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
                break
            self.chunk_manager.release_chunk(chunk)
            # real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
            can_shard_chunk_mem -= chunk.chunk_mem
        for chunk in can_evict_chunks:
            if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
                break
            self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
            # real saved mem is shard_mem, for simplicity we use chunk_mem
            can_offload_chunk_mem -= chunk.chunk_mem
        return 0, 0.0

    def setup_grads_device(
        self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
    ) -> None:
        total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)

        offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
        offloaded_optim_chunk_mem = 0
        chunks = set(self.chunk_manager.get_chunk(p) for p in params)
        for chunk in chunks:
            chunk_params = chunk.get_tensors()
            # init offload optim settings
            # keep gathered chunks are in CUDA
            if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
                device = get_accelerator().get_current_device()
            else:
                device = torch.device("cpu")
                # real offloaded mem is chunk.shard_mem, for simplicity we use chunk_mem here
                offloaded_optim_chunk_mem += chunk.chunk_mem
            for p in chunk_params:
                grads_device_map[p] = device
        self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
        self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

    def get_prefetch_chunks(
        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
    ) -> List[Chunk]:
        if is_warmup:  # no prefetch during warmup since we need compute_list
            return []
        can_prefetch = self.max_prefetch - len(async_works)
        prefetch = []
        for i in range(compute_idx + 1, len(compute_list)):
            for chunk in compute_list[i]:
                if len(prefetch) >= can_prefetch:
                    break
                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                    prefetch.append(chunk)
            else:
                continue
            break
        return prefetch


class AutoPlacementPolicy(PlacementPolicy):
    need_mem_stats: bool = True

    def __init__(
        self,
        chunk_manager: ChunkManager,
        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
        max_prefetch: int = 0,
        warmup_non_model_data_ratio: float = 0.8,
        steady_cuda_cap_ratio: float = 0.9,
        **kwargs,
    ) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
        # model data may use at most (1 - warmup_non_model_data_ratio) of CUDA memory in the warmup phase;
        # both ratios are configured via the constructor keyword arguments above
        self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
        self._steady_cuda_cap_ratio = steady_cuda_cap_ratio

    def evict_tensors(
        self,
        can_evict_chunks: List[Chunk],
        cuda_demand: int = 0,
        warmup: bool = True,
        compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
        compute_idx: int = 0,
        **kwargs,
    ) -> Tuple[int, float]:
        """
        Evict chunks from the CUDA device.

        Args:
            can_evict_chunks (List[Chunk]): the list of chunks that can be evicted.
            cuda_demand (int, optional): the volume of data needed on the CUDA device. Defaults to 0.
            warmup (bool, optional): whether the policy is still in the warmup phase. Defaults to True.
            compute_list (List[Tuple[Chunk, ...]], optional): the recorded sequence of chunk accesses,
                used to decide the eviction order after warmup. Defaults to None.
            compute_idx (int, optional): the index of the current compute step in compute_list. Defaults to 0.

        Raises:
            RuntimeError: if not enough CUDA memory can be freed to satisfy cuda_demand.

        Returns:
            Tuple[int, float]: the volume of memory that is evicted and the time spent on eviction.
        """
        start = time()
        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
        else:
            # max non-model-data CUDA memory consumption between this sampling moment and the next one
            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
            cuda_capacity *= self._steady_cuda_cap_ratio
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

        freed_cuda_model_data = 0

        if avail_cuda_model_data < cuda_demand:
            # Move (cuda_demand - avail_cuda_model_data) volume of chunks to the CPU
            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_chunks = can_evict_chunks
            if not warmup:
                # evict chunks whose next use is farthest in the future first
                to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
            for chunk in to_free_chunks:
                if freed_cuda_model_data >= to_free_cuda_model_data:
                    break

                self.chunk_manager.release_chunk(chunk)
                self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
                freed_cuda_model_data += chunk.chunk_mem
            if freed_cuda_model_data < to_free_cuda_model_data:
                raise RuntimeError(
                    f"Adjust layout failed! Not enough CUDA memory! "
                    f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                )
        return freed_cuda_model_data, time() - start

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
        next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
        for i in range(len(compute_list) - 1, compute_idx, -1):
            for chunk in compute_list[i]:
                if chunk in next_compute_idx:
                    next_compute_idx[chunk] = i
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]

    def setup_grads_device(
        self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
    ) -> None:
        for p in params:
            chunk = self.chunk_manager.get_chunk(p)
            # init offload optim settings
            # keep gathered chunks are in CUDA
            if chunk.keep_gathered:
                grads_device_map[p] = get_accelerator().get_current_device()
            else:
                grads_device_map[p] = torch.device("cpu")

    def get_prefetch_chunks(
        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
    ) -> List[Chunk]:
        if is_warmup:  # no prefetch during warmup since we need compute_list
            return []
        # modified from self.evict_tensors
        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
            get_accelerator().get_current_device()
        )
        max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

        prefetch_chunk_memory = 0
        can_prefetch = self.max_prefetch - len(async_works)
        prefetch = []
        for i in range(compute_idx + 1, len(compute_list)):
            for chunk in compute_list[i]:
                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
                    break
                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                    prefetch_chunk_memory += chunk.chunk_mem
                    prefetch.append(chunk)
            else:
                continue
            break
        return prefetch


class PlacementPolicyFactory:
    policies: Dict[str, Type[PlacementPolicy]] = {
        "auto": AutoPlacementPolicy,
        "static": StaticPlacementPolicy,
    }

    @staticmethod
    def create(policy_name: str) -> Type[PlacementPolicy]:
        if policy_name not in PlacementPolicyFactory.policies:
            raise TypeError(f"Unknown tensor placement policy {policy_name}")
        return PlacementPolicyFactory.policies[policy_name]

    @staticmethod
    def get_policy_names():
        return tuple(PlacementPolicyFactory.policies.keys())
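

# Usage sketch: how the factory is intended to be consumed. A caller (such as the Gemini
# runtime) resolves a policy *class* by name and instantiates it with a real ChunkManager
# and, for the "auto" policy, a ChunkMemStatsCollector; those objects are not constructed
# here. The guarded block below only inspects the registered policy classes.
if __name__ == "__main__":
    for name in PlacementPolicyFactory.get_policy_names():
        policy_cls = PlacementPolicyFactory.create(name)
        print(f"{name}: {policy_cls.__name__} (need_mem_stats={policy_cls.need_mem_stats})")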