You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/zero/gemini/placement_policy.py

275 lines
12 KiB

import functools
import warnings
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector
class PlacementPolicy(ABC):
need_mem_stats: bool = False
def __init__(
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
**kwargs,
) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
self.max_prefetch = max_prefetch
@abstractmethod
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
@abstractmethod
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
raise NotImplementedError
def get_prefetch_chunks(
self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
return [] # no prefetch by default
class StaticPlacementPolicy(PlacementPolicy):
def __init__(
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs,
) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
offload_param_frac = 0.0
self.shard_param_frac = shard_param_frac
self.offload_optim_frac = offload_optim_frac
self.offload_param_frac = offload_param_frac
# these should be initialized in setup_grads_device
self.keep_gathered_chunk_mem = 0.0
self.keep_cuda_chunk_mem = 0.0
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
can_offload_chunk_mem = can_shard_chunk_mem
for chunk in can_evict_chunks:
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
break
self.chunk_manager.release_chunk(chunk)
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
can_shard_chunk_mem -= chunk.chunk_mem
for chunk in can_evict_chunks:
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
break
self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
# real saved mem is shard_mem, for simplicity we use chunk_mem
can_offload_chunk_mem -= chunk.chunk_mem
return 0, 0.0
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
offloaded_optim_chunk_mem = 0
chunks = set(self.chunk_manager.get_chunk(p) for p in params)
for chunk in chunks:
params = chunk.get_tensors()
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
device = get_accelerator().get_current_device()
else:
device = torch.device("cpu")
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
offloaded_optim_chunk_mem += chunk.chunk_mem
for p in params:
grads_device_map[p] = device
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
if is_warmup: # no prefetch during warmup since we need compute_list
return []
can_prefetch = self.max_prefetch - len(async_works)
prefetch = []
for i in range(compute_idx + 1, len(compute_list)):
for chunk in compute_list[i]:
if len(prefetch) >= can_prefetch:
break
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
prefetch.append(chunk)
else:
continue
break
return prefetch
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
def __init__(
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs,
) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
self.__avail_cuda_model_data_for_prefetch = None
def evict_tensors(
self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs,
) -> Tuple[int, float]:
"""
Evict tensors from CUDA device.
Args:
can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted.
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
compute_idx (int, optional): the idx of computing device. Defaults to 0.
Raises:
RuntimeError:
Returns:
int: the volume of memory that is evicted
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
if avail_cuda_model_data < cuda_demand:
# Move cuda_demand - avail_cuda_model_data volume of tensors
# to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
to_free_chunks = can_evict_chunks
if not warmup:
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
# print(self._sort_can_evict_chunks.cache_info())
for chunk in to_free_chunks:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device("cpu"))
freed_cuda_model_data += chunk.chunk_mem
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)
self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data + freed_cuda_model_data
return freed_cuda_model_data, time() - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
def setup_grads_device(
self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
) -> None:
for p in params:
chunk = self.chunk_manager.get_chunk(p)
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered:
grads_device_map[p] = get_accelerator().get_current_device()
else:
grads_device_map[p] = torch.device("cpu")
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
if is_warmup: # no prefetch during warmup since we need compute_list
return []
avail_cuda_model_data = self.__avail_cuda_model_data_for_prefetch
self.__avail_cuda_model_data_for_prefetch = None # incase of double use
prefetch_chunk_memory = 0
can_prefetch = self.max_prefetch - len(async_works)
prefetch = []
for i in range(compute_idx + 1, len(compute_list)):
for chunk in compute_list[i]:
if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
break
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
prefetch_chunk_memory += chunk.chunk_mem
prefetch.append(chunk)
else:
continue
break
return prefetch
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
"auto": AutoPlacementPolicy,
"static": StaticPlacementPolicy,
}
@staticmethod
def create(policy_name: str) -> Type[PlacementPolicy]:
if policy_name not in PlacementPolicyFactory.policies:
raise TypeError(f"Unknown tensor placement policy {policy_name}")
return PlacementPolicyFactory.policies[policy_name]
@staticmethod
def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys())