mirror of https://github.com/hpcaitech/ColossalAI
[gemini] add stateful tensor container (#867)
parent
d01d3b8cb0
commit
3107817172
|
@ -0,0 +1,131 @@
|
||||||
|
import queue
|
||||||
|
import heapq
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||||
|
|
||||||
|
|
||||||
|
def evict_check(st: StatefulTensor) -> bool:
    """Return True if `st` is currently evictable.

    A stateful tensor may be evicted when it still resides on a CUDA device
    and is not pinned there by being in the COMPUTE state.
    """
    on_cuda = st.device.type == 'cuda'
    pinned = st.state is TensorState.COMPUTE
    return on_cuda and not pinned
|
||||||
|
|
||||||
|
|
||||||
|
# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
    """Abstract container of stateful tensors that are candidates for CUDA eviction.

    A stateful tensor belongs in this container when it satisfies two
    conditions: it has not been evicted yet (its device type is CUDA), and it
    is not pinned in CUDA memory (its state is not COMPUTE).

    A stateful tensor should be inserted when it transitions from COMPUTE to a
    HOLD-like state, and candidates are removed through `pop` (called by
    eviction logic).

    To obtain a better eviction policy, callers may supply the computation
    step indices at which each stateful tensor is used. A heap keyed on the
    next use step then lets `pop` return the stateful tensor whose next use
    lies furthest in the future from the current computation step.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        # maps each stateful tensor to the steps at which it is used
        self.compute_step_dict = compute_step_dict
        # total number of computation steps in one iteration
        self.total_step = total_step

    @abstractmethod
    def empty(self) -> bool:
        """Return True when the container holds no tensors."""
        pass

    @abstractmethod
    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Initialize the container with an initial list of candidate tensors."""
        pass

    @abstractmethod
    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert one candidate tensor, given the current computation step."""
        pass

    @abstractmethod
    def pop(self) -> Optional[StatefulTensor]:
        """Remove and return an evictable tensor, or None if none qualifies."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class QueueSTContainer(BaseSTContainer):
    """FIFO stateful tensor container, used by the 'cpu' tensor placement policy.

    Potentially evictable stateful tensors are popped in the order they were
    inserted.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # the underlying queue; built lazily in `create`
        self.container = None

    def empty(self) -> bool:
        assert self.container is not None
        return self.container.empty()

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Fill a fresh FIFO queue with the given tensors."""
        self.container = queue.SimpleQueue()
        for st in stateful_tensor_list:
            self.container.put(st)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        # the FIFO policy has no use for `cur_step`; it is accepted only to
        # honor the BaseSTContainer interface
        self.container.put(stateful_tensor)

    def pop(self) -> Optional[StatefulTensor]:
        """Return the oldest evictable tensor, or None if none qualifies.

        Tensors failing `evict_check` are discarded from the queue as they
        are encountered.
        """
        while not self.empty():
            candidate = self.container.get()
            if evict_check(candidate):
                return candidate
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class HeapSTContainer(BaseSTContainer):
    """Heap type stateful tensor container. This is used in 'auto' tensor placement policy.

    It pops potential evictable stateful tensors in the order of the distance between
    current step and next used step: the tensor used furthest in the future is
    popped first.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        self.container = None
        # Monotonically increasing sequence number inserted between the weight
        # and the tensor in each heap entry. Without it, two entries with equal
        # weights would make heapq compare the StatefulTensor objects directly,
        # which raises TypeError since StatefulTensor defines no ordering.
        # Equal weights are reachable: any two tensors with no future use both
        # get weight -total_step.
        self._seq = 0

    def empty(self) -> bool:
        assert self.container is not None
        return self.container == []

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Build the heap from the given tensors, before any step has run (cur_step=-1)."""
        self.container = []
        for stateful_tensor in stateful_tensor_list:
            # we want to pop the tensor which has the greatest next_step
            # so the weight is next_step multiplied by -1
            weight = -self.__get_next_compute_step(stateful_tensor, -1)
            self.container.append((weight, self._seq, stateful_tensor))
            self._seq += 1
        heapq.heapify(self.container)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert one tensor, weighted by its next use step after `cur_step`."""
        # we want to pop the tensor which has the greatest next_step
        # so the weight is next_step multiplied by -1
        weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
        heapq.heappush(self.container, (weight, self._seq, stateful_tensor))
        self._seq += 1

    def pop(self) -> Optional[StatefulTensor]:
        """Return the evictable tensor used furthest in the future, or None.

        Tensors failing `evict_check` are discarded from the heap as they are
        encountered.
        """
        ret = None
        while not self.empty():
            _, _, out_tensor = heapq.heappop(self.container)
            if evict_check(out_tensor):
                ret = out_tensor
                break
        return ret

    def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int) -> int:
        """Return the first step strictly after `cur_step` at which the tensor
        is used; `total_step` if it is never used again.

        Assumes each step list in `compute_step_dict` is sorted ascending —
        TODO confirm against callers.
        """
        next_step = self.total_step
        step_list = self.compute_step_dict[stateful_tensor]
        for step in step_list:
            if step > cur_step:
                next_step = step
                break
        return next_step
|
|
@ -0,0 +1,74 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
|
||||||
|
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
def test_stateful_tensor_container():
    """Exercise QueueSTContainer and HeapSTContainer against a fixed 6-step schedule.

    NOTE(review): requires a CUDA device — the tensors are created with
    device='cuda' and eviction only applies to CUDA-resident tensors.
    Each tensor's first dimension (1, 2, 3) doubles as its identity in the
    expected-result lists.
    """
    st1 = StatefulTensor(torch.randn(1, device='cuda'))
    st2 = StatefulTensor(torch.randn(2, device='cuda'))
    st3 = StatefulTensor(torch.randn(3, device='cuda'))
    stateful_tensor_list = [st1, st2, st3]
    # the tensor computed at each of the 6 steps
    step_list = [st1, st2, st3, st3, st2, st1]

    # steps at which each tensor is used (matches step_list above)
    compute_step_dict = dict()
    compute_step_dict[st1] = [0, 5]
    compute_step_dict[st2] = [1, 4]
    compute_step_dict[st3] = [2, 3]

    def run_queue_test():
        # test queue container
        queue_container = QueueSTContainer(compute_step_dict, 6)
        queue_container.create(stateful_tensor_list)

        # sizes of the tensors popped (FIFO order among evictable tensors)
        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            # pin the tensor being computed so it cannot be evicted this step
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = queue_container.pop()
            # simulate eviction of the popped tensor
            st_out.move_to(torch.device('cpu'))

            res_list.append(st_out.payload.size(0))

            # bring the computed tensor back to CUDA and re-register it
            stateful_tensor.move_to(torch.device('cuda'))
            queue_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [2, 3, 1, 2, 3, 2]

    run_queue_test()

    def run_heap_test():
        # test heap container
        # restore all tensors to CUDA after the queue test evicted some
        st1.move_to(torch.device('cuda'))
        st2.move_to(torch.device('cuda'))
        st3.move_to(torch.device('cuda'))

        heap_container = HeapSTContainer(compute_step_dict, 6)
        heap_container.create(stateful_tensor_list)

        # sizes of the tensors popped (furthest-next-use order)
        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = heap_container.pop()

            # unlike the queue test, pop may find nothing evictable here
            if st_out is not None:
                res_list.append(st_out.payload.size(0))
                st_out.move_to(torch.device('cpu'))

            stateful_tensor.move_to(torch.device('cuda'))
            heap_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        # only 5 pops succeed across the 6 steps
        assert res_list == [3, 1, 2, 3, 2]

    run_heap_test()
||||||
|
if __name__ == '__main__':
    # allow running this test directly without pytest
    test_stateful_tensor_container()
|
Loading…
Reference in New Issue