You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/gemini/stateful_tensor_container.py

132 lines
4.9 KiB

import queue
import heapq
from abc import ABC, abstractmethod
from typing import Optional, List, Dict
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
def evict_check(st: StatefulTensor) -> bool:
    """Return whether ``st`` can currently be evicted from CUDA.

    A stateful tensor qualifies only when it is not in the COMPUTE state and its
    device type is ``'cuda'`` (i.e. it has not already been moved off the GPU).
    """
    # Test the state first so ``st.device`` is only consulted for non-COMPUTE tensors.
    if st.state is TensorState.COMPUTE:
        return False
    return st.device.type == 'cuda'
# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
    """Abstract container for stateful tensors that are candidates for eviction from CUDA.

    A stateful tensor is a candidate when two conditions hold: it has not been evicted yet
    (its device type is CUDA), and it is not pinned in CUDA memory (its state is not COMPUTE).

    A tensor should be pushed into the container when it changes from COMPUTE to a
    HOLD-like state. Tensors are taken out again by ``pop`` when eviction is needed
    (e.g. from ``evict_tensors``).

    To approach an optimal eviction policy, users may supply the computation-step indices
    of every stateful tensor. A heap can then keep the evictable tensors ordered so that
    popping yields the tensor whose next use is furthest from the current step.

    Args:
        compute_step_dict: maps each stateful tensor to the list of step indices at which
            it is used for computation.
        total_step: number of computation steps; implementations may use it as the
            "next step" of a tensor that is not used again.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        self.compute_step_dict = compute_step_dict
        self.total_step = total_step

    @abstractmethod
    def empty(self) -> bool:
        """Return True when the container holds no more candidates."""

    @abstractmethod
    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """(Re)initialise the container with the given stateful tensors."""

    @abstractmethod
    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert ``stateful_tensor`` as a candidate, given the current step ``cur_step``."""

    @abstractmethod
    def pop(self) -> Optional[StatefulTensor]:
        """Remove and return the next evictable tensor, or None if there is none."""
class QueueSTContainer(BaseSTContainer):
    """FIFO container of potentially evictable stateful tensors.

    Used by the 'cpu' tensor placement policy: candidates come back out in insertion
    order, and any that are no longer evictable when reached are discarded.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # Set up by ``create``; stays None until then.
        self.container = None

    def empty(self) -> bool:
        assert self.container is not None
        return self.container.empty()

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        self.container = queue.SimpleQueue()
        for candidate in stateful_tensor_list:
            self.container.put(candidate)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        # FIFO ordering does not depend on ``cur_step``; it exists to match the base interface.
        self.container.put(stateful_tensor)

    def pop(self) -> Optional[StatefulTensor]:
        # Drain entries until one is still evictable; stale entries are dropped.
        while not self.empty():
            candidate = self.container.get()
            if evict_check(candidate):
                return candidate
        return None
class HeapSTContainer(BaseSTContainer):
    """Heap type stateful tensor container. This is used in 'auto' tensor placement policy.

    It pops potential evictable stateful tensors in order of the distance between the
    current step and their next compute step: the tensor used furthest in the future is
    popped first.

    Heap entries are ``(weight, seq, stateful_tensor)`` triples:

    * ``weight`` is the next compute step multiplied by -1, because ``heapq`` is a
      min-heap and we want the greatest next step first;
    * ``seq`` is a monotonically increasing insertion counter. It breaks ties between
      entries with equal weight (common: every tensor that is never used again gets
      ``total_step``), so the heap never has to compare ``StatefulTensor`` objects, which
      are not assumed to be orderable. Among ties, entries pop in insertion order.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # Set up by ``create``; stays None until then.
        self.container = None
        # Tie-breaking insertion counter for heap entries.
        self._seq = 0

    def empty(self) -> bool:
        assert self.container is not None
        return not self.container

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        self.container = []
        for stateful_tensor in stateful_tensor_list:
            # cur_step = -1 makes every recorded step (all >= 0) eligible as "next step"
            self.container.append(self._make_entry(stateful_tensor, -1))
        heapq.heapify(self.container)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        heapq.heappush(self.container, self._make_entry(stateful_tensor, cur_step))

    def pop(self) -> Optional[StatefulTensor]:
        # Drain entries until one is still evictable; stale entries are dropped.
        while not self.empty():
            _, _, out_tensor = heapq.heappop(self.container)
            if evict_check(out_tensor):
                return out_tensor
        return None

    def _make_entry(self, stateful_tensor: StatefulTensor, cur_step: int):
        """Build a heap entry ``(weight, seq, stateful_tensor)`` for ``stateful_tensor``."""
        # we want to pop the tensor which has the greatest next_step
        # so the weight is next_step multiplied by -1
        weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
        entry = (weight, self._seq, stateful_tensor)
        self._seq += 1
        return entry

    def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int) -> int:
        """Return the first recorded compute step of ``stateful_tensor`` after ``cur_step``.

        If the tensor is not used in the future, ``total_step`` is returned so that it
        sorts as the furthest possible use. Steps are scanned in list order, so this is
        the smallest later step provided the recorded step list is ascending (presumably
        the case, since steps are recorded in execution order — confirm against the tracer).
        """
        step_list = self.compute_step_dict[stateful_tensor]
        return next((step for step in step_list if step > cur_step), self.total_step)