mirror of https://github.com/hpcaitech/ColossalAI
75 lines
2.3 KiB
Python
75 lines
2.3 KiB
Python
import pytest
|
|
import torch
|
|
|
|
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
|
|
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires a CUDA device')
def test_stateful_tensor_container():
    """Check the eviction order produced by QueueSTContainer and HeapSTContainer.

    Three CUDA-resident StatefulTensors of sizes 1, 2 and 3 are used in the
    order given by ``step_list``. At each step:

    * the tensor about to be used is marked ``TensorState.COMPUTE``;
    * one tensor is popped from the container and moved to CPU;
    * the in-use tensor is moved (back) to CUDA, pushed with the current step
      index, and marked ``TensorState.HOLD``.

    The sizes of the popped tensors (i.e. which tensor was chosen for
    eviction) are compared against the expected order for each container.
    """
    st1 = StatefulTensor(torch.randn(1, device='cuda'))
    st2 = StatefulTensor(torch.randn(2, device='cuda'))
    st3 = StatefulTensor(torch.randn(3, device='cuda'))
    stateful_tensor_list = [st1, st2, st3]

    # Order in which the tensors are used, one tensor per step.
    step_list = [st1, st2, st3, st3, st2, st1]
    # Single source of truth for the number of steps; it is passed as the
    # containers' second argument (presumably the total step count — confirm
    # against the container API) and bounds both loops below.
    num_steps = len(step_list)

    # Step indices at which each tensor is used; must agree with step_list.
    compute_step_dict = {
        st1: [0, 5],
        st2: [1, 4],
        st3: [2, 3],
    }

    def run_queue_test():
        # test queue container
        queue_container = QueueSTContainer(compute_step_dict, num_steps)
        queue_container.create(stateful_tensor_list)

        res_list = []
        for i, stateful_tensor in enumerate(step_list):
            stateful_tensor.trans_state(TensorState.COMPUTE)

            # Evict one tensor to CPU and record which one (by its size).
            st_out = queue_container.pop()
            st_out.move_to(torch.device('cpu'))
            res_list.append(st_out.payload.size(0))

            # Make the in-use tensor resident again and hand it back.
            stateful_tensor.move_to(torch.device('cuda'))
            queue_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [2, 3, 1, 2, 3, 2]

    run_queue_test()

    def run_heap_test():
        # test heap container
        # The queue test moves evicted tensors to CPU, so re-home all of them
        # on CUDA to start from the same state as the queue test did.
        st1.move_to(torch.device('cuda'))
        st2.move_to(torch.device('cuda'))
        st3.move_to(torch.device('cuda'))

        heap_container = HeapSTContainer(compute_step_dict, num_steps)
        heap_container.create(stateful_tensor_list)

        res_list = []
        for i, stateful_tensor in enumerate(step_list):
            stateful_tensor.trans_state(TensorState.COMPUTE)

            # pop() may return None here: the expected list below has one
            # fewer entry than there are steps.
            st_out = heap_container.pop()
            if st_out is not None:
                res_list.append(st_out.payload.size(0))
                st_out.move_to(torch.device('cpu'))

            stateful_tensor.move_to(torch.device('cuda'))
            heap_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [3, 1, 2, 3, 2]

    run_heap_test()
|
|
|
|
|
|
if __name__ == '__main__':
    # Direct-execution entry point: run the container test without pytest.
    test_stateful_tensor_container()
|