You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_gemini/test_stateful_tensor_contai...

75 lines
2.3 KiB

import pytest
import torch
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer
@pytest.mark.dist
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires a CUDA device')
def test_stateful_tensor_container():
    """Check the pop order of ``QueueSTContainer`` and ``HeapSTContainer``.

    Three ``StatefulTensor`` objects of sizes 1, 2 and 3 are created on CUDA
    and accessed in the order ``st1, st2, st3, st3, st2, st1``. At each step:

    * the tensor about to be used is marked ``COMPUTE``;
    * one tensor is popped from the container and moved to CPU;
    * the used tensor is moved back to CUDA, pushed with the current step,
      and marked ``HOLD``.

    The sizes of the popped tensors are compared against the expected
    sequence for each container type.
    """
    st1 = StatefulTensor(torch.randn(1, device='cuda'))
    st2 = StatefulTensor(torch.randn(2, device='cuda'))
    st3 = StatefulTensor(torch.randn(3, device='cuda'))
    stateful_tensor_list = [st1, st2, st3]

    # Access sequence; the position of each entry is its compute step.
    step_list = [st1, st2, st3, st3, st2, st1]
    total_steps = len(step_list)

    # Steps at which each tensor is used; must agree with positions in step_list.
    compute_step_dict = {
        st1: [0, 5],
        st2: [1, 4],
        st3: [2, 3],
    }

    def run_queue_test():
        # test queue container
        queue_container = QueueSTContainer(compute_step_dict, total_steps)
        queue_container.create(stateful_tensor_list)
        res_list = []

        for step, stateful_tensor in enumerate(step_list):
            stateful_tensor.trans_state(TensorState.COMPUTE)
            # This loop assumes pop() always yields a tensor for the queue
            # container (one result expected per step).
            st_out = queue_container.pop()
            st_out.move_to(torch.device('cpu'))
            res_list.append(st_out.payload.size(0))

            stateful_tensor.move_to(torch.device('cuda'))
            queue_container.push(stateful_tensor, step)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [2, 3, 1, 2, 3, 2]

    def run_heap_test():
        # test heap container
        # The queue test may leave tensors on CPU; start again from CUDA.
        for st in stateful_tensor_list:
            st.move_to(torch.device('cuda'))

        heap_container = HeapSTContainer(compute_step_dict, total_steps)
        heap_container.create(stateful_tensor_list)
        res_list = []

        for step, stateful_tensor in enumerate(step_list):
            stateful_tensor.trans_state(TensorState.COMPUTE)
            # The heap's pop() may return None. Such steps add nothing to
            # res_list, which is why fewer results than steps are expected.
            st_out = heap_container.pop()
            if st_out is not None:
                res_list.append(st_out.payload.size(0))
                st_out.move_to(torch.device('cpu'))

            stateful_tensor.move_to(torch.device('cuda'))
            heap_container.push(stateful_tensor, step)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [3, 1, 2, 3, 2]

    run_queue_test()
    run_heap_test()
if __name__ == '__main__':
    # Direct invocation bypasses pytest markers (including any skip
    # conditions); the test allocates tensors with device='cuda'.
    test_stateful_tensor_container()