mirror of https://github.com/hpcaitech/ColossalAI
[gemini] add stateful tensor container (#867)
parent d01d3b8cb0
commit 3107817172
colossalai/gemini/stateful_tensor_container.py
@@ -0,0 +1,131 @@
import queue
import heapq
from abc import ABC, abstractmethod
from typing import Optional, List, Dict

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState


def evict_check(st: StatefulTensor) -> bool:
    # A stateful tensor is evictable only if it still lives on CUDA
    # and is not pinned there by an ongoing computation (state COMPUTE).
    if st.state is not TensorState.COMPUTE and st.device.type == 'cuda':
        return True
    return False

# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
    """A container that stores all potentially evictable stateful tensors, i.e. those that
    can be moved off CUDA. Such a stateful tensor must satisfy two conditions: it has not
    been evicted yet (its device type is CUDA), and it is not pinned in CUDA memory
    (its state is not COMPUTE).

    The container should receive a stateful tensor when that tensor moves from COMPUTE
    to a HOLD-like state, and stateful tensors are popped from it (via `pop`) inside the
    eviction routine, `evict_tensors`.

    To obtain an optimal eviction policy, users may supply the computation step indices
    at which each stateful tensor is used. A heap can then maintain all potentially
    evictable stateful tensors, so that popping yields the stateful tensor whose next
    use is furthest from the current computation step.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        self.compute_step_dict = compute_step_dict
        self.total_step = total_step

    @abstractmethod
    def empty(self) -> bool:
        pass

    @abstractmethod
    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        pass

    @abstractmethod
    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        pass

    @abstractmethod
    def pop(self) -> Optional[StatefulTensor]:
        pass

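The "furthest next use" policy the docstring describes can be illustrated in isolation. A minimal sketch, assuming plain string keys and a hypothetical `next_use` helper in place of real `StatefulTensor` objects and `__get_next_compute_step`:

import heapq

# Steps at which each (hypothetical) tensor is used; mirrors compute_step_dict.
compute_steps = {'a': [0, 5], 'b': [1, 4], 'c': [2, 3]}
total_step = 6

def next_use(name, cur_step):
    # First step after cur_step that uses `name`; total_step if never used again.
    return next((s for s in compute_steps[name] if s > cur_step), total_step)

# Negating the key turns Python's min-heap into a max-heap on next_use.
heap = [(-next_use(name, -1), name) for name in compute_steps]
heapq.heapify(heap)

# The first pop is the tensor whose next use lies furthest in the future:
# 'c' is next needed at step 2, later than 'a' (step 0) and 'b' (step 1).
print(heapq.heappop(heap))    # (-2, 'c')
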
class QueueSTContainer(BaseSTContainer):
    """Queue-type stateful tensor container, used by the 'cpu' tensor placement policy.
    It pops potentially evictable stateful tensors in FIFO order.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        self.container = None

    def empty(self) -> bool:
        assert self.container is not None
        return self.container.empty()

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        self.container = queue.SimpleQueue()
        for stateful_tensor in stateful_tensor_list:
            self.container.put(stateful_tensor)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        self.container.put(stateful_tensor)

    def pop(self) -> Optional[StatefulTensor]:
        # Drain the queue until an evictable tensor appears; tensors that
        # fail evict_check are discarded rather than re-queued.
        ret = None
        while not self.empty():
            out_tensor = self.container.get()
            if evict_check(out_tensor):
                ret = out_tensor
                break

        return ret

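A minimal sketch of the FIFO behaviour, using a hypothetical `DummyTensor` stub with an `evictable` flag in place of `StatefulTensor` and `evict_check`; note how the pop loop discards skipped entries rather than re-queueing them, just as `QueueSTContainer.pop` does:

import queue

class DummyTensor:    # hypothetical stand-in for StatefulTensor
    def __init__(self, name, evictable=True):
        self.name = name
        self.evictable = evictable    # stands in for evict_check(tensor)

pool = queue.SimpleQueue()
for t in (DummyTensor('a'), DummyTensor('b', evictable=False), DummyTensor('c')):
    pool.put(t)

def pop_evictable():
    # Mirrors QueueSTContainer.pop: skip (and drop) non-evictable entries.
    while not pool.empty():
        t = pool.get()
        if t.evictable:
            return t
    return None

print(pop_evictable().name)    # 'a'  -- FIFO order
print(pop_evictable().name)    # 'c'  -- 'b' was dropped, like a pinned tensor
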
class HeapSTContainer(BaseSTContainer):
    """Heap-type stateful tensor container, used by the 'auto' tensor placement policy.
    It pops potentially evictable stateful tensors ordered by the distance between the
    current step and the step at which each tensor is next used.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        self.container = None

    def empty(self) -> bool:
        assert self.container is not None
        return self.container == []

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        self.container = []
        for stateful_tensor in stateful_tensor_list:
            # we want to pop the tensor with the greatest next_step,
            # so the weight is next_step multiplied by -1
            weight = -self.__get_next_compute_step(stateful_tensor, -1)
            self.container.append((weight, stateful_tensor))
        heapq.heapify(self.container)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        # we want to pop the tensor with the greatest next_step,
        # so the weight is next_step multiplied by -1
        weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
        heapq.heappush(self.container, (weight, stateful_tensor))

    def pop(self) -> Optional[StatefulTensor]:
        ret = None
        while not self.empty():
            _, out_tensor = heapq.heappop(self.container)
            if evict_check(out_tensor):
                ret = out_tensor
                break
        return ret

    def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int):
        # Compute the index of the next step at which the tensor is used.
        # If the tensor is not used in the future, next_step is set to
        # the maximum, total_step.
        next_step = self.total_step
        step_list = self.compute_step_dict[stateful_tensor]
        for step in step_list:
            if step > cur_step:
                next_step = step
                break
        return next_step
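One subtlety of the `(weight, stateful_tensor)` heap entries, shown below with a hypothetical unorderable `Payload` class: when two weights tie, `heapq` falls back to comparing the second tuple element, so the payload type must be orderable, or ties must never meet in a comparison. This is a property of Python's `heapq` itself, not something `HeapSTContainer` adds:

import heapq

class Payload:    # hypothetical payload with no ordering defined
    pass

heap = []
heapq.heappush(heap, (-6, Payload()))
try:
    heapq.heappush(heap, (-6, Payload()))   # tie on the weight
except TypeError as err:
    print('tie fell through to the payload:', err)
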
@@ -0,0 +1,74 @@
import pytest
import torch

from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer


@pytest.mark.dist
def test_stateful_tensor_container():
    st1 = StatefulTensor(torch.randn(1, device='cuda'))
    st2 = StatefulTensor(torch.randn(2, device='cuda'))
    st3 = StatefulTensor(torch.randn(3, device='cuda'))
    stateful_tensor_list = [st1, st2, st3]
    step_list = [st1, st2, st3, st3, st2, st1]

    compute_step_dict = dict()
    compute_step_dict[st1] = [0, 5]
    compute_step_dict[st2] = [1, 4]
    compute_step_dict[st3] = [2, 3]

    def run_queue_test():
        # test queue container
        queue_container = QueueSTContainer(compute_step_dict, 6)
        queue_container.create(stateful_tensor_list)

        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            # the current tensor is pinned (COMPUTE), so pop() must pick another one
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = queue_container.pop()
            st_out.move_to(torch.device('cpu'))

            res_list.append(st_out.payload.size(0))

            stateful_tensor.move_to(torch.device('cuda'))
            queue_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [2, 3, 1, 2, 3, 2]

    run_queue_test()

    def run_heap_test():
        # test heap container
        st1.move_to(torch.device('cuda'))
        st2.move_to(torch.device('cuda'))
        st3.move_to(torch.device('cuda'))

        heap_container = HeapSTContainer(compute_step_dict, 6)
        heap_container.create(stateful_tensor_list)

        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            stateful_tensor.trans_state(TensorState.COMPUTE)
            # the heap may run out of evictable tensors, so pop() can return None
            st_out = heap_container.pop()

            if st_out is not None:
                res_list.append(st_out.payload.size(0))
                st_out.move_to(torch.device('cpu'))

            stateful_tensor.move_to(torch.device('cuda'))
            heap_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [3, 1, 2, 3, 2]

    run_heap_test()


if __name__ == '__main__':
    test_stateful_tensor_container()